
gordicaleksa / pytorch-gat

My implementation of the original GAT paper (Veličković et al.). I've additionally included the playground.py file for visualizing the Cora dataset, GAT embeddings, the attention mechanism, and entropy histograms. Both Cora (transductive) and PPI (inductive) examples are supported!

Home Page: https://youtube.com/c/TheAIEpiphany

License: MIT License

Python 2.46% Jupyter Notebook 97.54%
gat graph-attention-networks attention-mechanism self-attention pytorch python attention pytorch-gat gat-tutorial deep-learning graph-attention-network pytorch-implementation jupyter

pytorch-gat's Introduction

GAT - Graph Attention Network (PyTorch) 💻 + graphs + 📣 = ❤️

This repo contains a PyTorch implementation of the original GAT paper (:link: Veličković et al.).
It's aimed at making it easy to start playing and learning about GAT and GNNs in general.

Table of Contents

What are GNNs?

Graph neural networks are a family of neural networks that deal with signals defined over graphs!

Graphs can model many interesting natural phenomena, so you'll see them used in everything from particle physics at the Large Hadron Collider (LHC) to fake news detection - and the list goes on and on!

GAT is a representative of spatial (convolutional) GNNs. Since CNNs had tremendous success in the field of computer vision, researchers decided to generalize them to graphs, and so here we are! 🤓

Here is a schematic of GAT's structure:

Cora visualized

You can't just start talking about GNNs without mentioning the single most famous graph dataset - Cora.

Nodes in Cora represent research papers and the links are, you guessed it, citations between those papers.

I've added a utility for visualizing Cora and doing basic network analysis. Here is what Cora looks like:

Node size corresponds to a node's degree (i.e. the number of incoming/outgoing edges). Edge thickness roughly corresponds to how "popular" or "connected" an edge is (edge betweenness is the nerdy term - check out the code).

And here is a plot showing the degree distribution on Cora:

In and out degree plots are the same since we're dealing with an undirected graph.

On the bottom plot (degree distribution) you can see an interesting peak happening in the [2, 4] range. This means that the majority of nodes have a small number of edges, but there is 1 node that has 169 edges! (the big green node)
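
If you want to reproduce this kind of network analysis yourself, here's a minimal sketch of computing degrees and edge betweenness with networkx (the toy edge list is just a placeholder - the actual code in playground.py works on the real Cora data):

import networkx as nx
import matplotlib.pyplot as plt

# toy (src, dst) edge list standing in for Cora - swap in the real edge index
edge_index = [(0, 1), (0, 2), (0, 3), (1, 2), (3, 4)]

g = nx.Graph()  # undirected, just like Cora in the plots above
g.add_edges_from(edge_index)

degrees = [deg for _, deg in g.degree()]
print(f'max degree = {max(degrees)}')  # on real Cora this is the big green node (169 edges)

# "how popular/connected an edge is" = edge betweenness centrality
betweenness = nx.edge_betweenness_centrality(g)
print(max(betweenness.values()))

plt.hist(degrees, bins=range(1, max(degrees) + 2))
plt.xlabel('node degree')
plt.ylabel('number of nodes')
plt.show()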

Attention visualized

Once we have a fully-trained GAT model, we can visualize the attention that certain nodes have learned.
Nodes use attention to decide how to aggregate their neighborhood - enough talk, let's see it:

This is one of Cora's highest-degree nodes (i.e. it has the most edges/citations). Nodes of the same class share a color. You can clearly see 2 things from this plot:

  • The graph is homophilic, meaning similar nodes (nodes with the same class) tend to cluster together.
  • Edge thickness on this chart is a function of attention, and since all edges have the same thickness, GAT basically learned to do something similar to GCN!

Similar rules hold for smaller neighborhoods. Also notice the self-edges:

On the other hand, GAT learns much more interesting attention patterns on PPI:

On the left we can see that 6 neighbors are receiving a non-negligible amount of attention, and on the right all of the attention is focused on a single neighbor.

Finally, 2 more interesting patterns: on the left a strong self-edge, and on the right a single neighbor receiving the bulk of the attention while what's left is distributed equally across the rest of the neighborhood:

Important note: all of the PPI visualizations are only possible for the first GAT layer. For some reason the attention coefficients for the second and third layers are almost all 0s (even though I achieved the published results).

Entropy histograms

Another way to see that GAT isn't learning interesting attention patterns on Cora (i.e. that it's learning constant attention) is to treat each node neighborhood's attention weights as a probability distribution, calculate the entropy, and accumulate that information across every node's neighborhood.

We'd love GAT's attention distributions to be skewed. You can see in orange what the histogram looks like for ideal uniform distributions, and in light blue the learned distributions - they are exactly the same!

I've plotted only a single attention head from the first layer (out of 8) because they're all the same!

On the other hand, GAT learns much more interesting attention patterns on PPI:

As expected, the uniform distribution entropy histogram lies to the right (orange) since uniform distributions have the highest entropy.
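
If you'd like to compute such an entropy histogram for your own model, here's a minimal sketch of the idea, assuming you've already extracted the softmax-normalized attention weights for each node's neighborhood (the per-node dictionary below is a made-up stand-in; the real extraction logic lives in playground.py):

import numpy as np
from scipy.stats import entropy
import matplotlib.pyplot as plt

# made-up per-node attention weights over each node's neighborhood (already softmax-normalized)
neighborhood_attention = {
    0: np.array([0.90, 0.05, 0.05]),        # skewed - "interesting" attention
    1: np.array([0.25, 0.25, 0.25, 0.25]),  # uniform - GCN-like constant attention
}

learned_entropies, uniform_entropies = [], []
for attn in neighborhood_attention.values():
    learned_entropies.append(entropy(attn))
    # the reference is a uniform distribution over the same neighborhood size (entropy = log N)
    uniform_entropies.append(entropy(np.ones(len(attn)) / len(attn)))

plt.hist(uniform_entropies, alpha=0.5, label='ideal uniform distributions')
plt.hist(learned_entropies, alpha=0.5, label='learned attention distributions')
plt.xlabel('neighborhood entropy')
plt.ylabel('number of neighborhoods')
plt.legend()
plt.show()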

Analyzing Cora's embedding space (t-SNE)

Ok, we've seen attention! What else is there to visualize? Well, let's visualize the learned embeddings from GAT's last layer. The output of GAT is a tensor of shape = (2708, 7) where 2708 is the number of nodes in Cora and 7 is the number of classes. Once we project those 7-dim vectors into 2D, using t-SNE, we get this:

We can see that the nodes with the same label/class are roughly clustered together - with these representations it's easy to train a simple classifier on top that will tell us which class the node belongs to.

Note: I've tried UMAP as well but didn't get nicer results + it has a lot of dependencies if you want to use their plot util.
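
For reference, here's a minimal sketch of that projection using scikit-learn's t-SNE (the random tensors below are stand-ins for GAT's actual (2708, 7) output and the true node labels):

import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# stand-ins for GAT's real output: (2708, 7) class scores and the ground-truth labels
gat_output = np.random.randn(2708, 7)
node_labels = np.random.randint(0, 7, size=2708)

# project the 7-dim vectors down to 2D
embeddings_2d = TSNE(n_components=2, perplexity=30).fit_transform(gat_output)

plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=node_labels, s=5, cmap='tab10')
plt.show()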

Setup

So we talked about what GNNs are, and what they can do for you (among other things).
Let's get this thing running! Follow the next steps:

  1. git clone https://github.com/gordicaleksa/pytorch-GAT
  2. Open your Anaconda console and navigate into the project directory: cd path_to_repo
  3. Run conda env create from the project directory (this will create a brand new conda environment).
  4. Run activate pytorch-gat (for running scripts from your console) or set up the interpreter in your IDE

That's it! It should work out of the box - the environment.yml file deals with the dependencies.


The PyTorch pip package comes bundled with some version of CUDA/cuDNN, but it is highly recommended that you install system-wide CUDA beforehand, mostly because of the GPU drivers. I also recommend using the Miniconda installer to get conda on your system. Follow points 1 and 2 of this setup and use the most up-to-date versions of Miniconda and CUDA/cuDNN for your system.

Usage

Option 1: Jupyter Notebook

Just run jupyter notebook from your Anaconda console and it will open up a session in your default browser.
Open The Annotated GAT.ipynb and you're ready to play!


Note: if you get DLL load failed while importing win32api: The specified module could not be found,
just do pip uninstall pywin32 and then either pip install pywin32 or conda install pywin32 - that should fix it!

Option 2: Use your IDE of choice

You just need to link the Python environment you created in the setup section.

Training GAT

FYI, my GAT implementation achieves the published results:

  • On Cora I get 82-83% accuracy on the test nodes
  • On PPI I achieved a 0.973 micro-F1 score (and actually even higher)

Everything needed to train GAT on Cora is already set up. To run it (from the console) just call:
python training_script_cora.py

You could also potentially:

  • add --should_visualize to visualize your graph data
  • add --should_test to evaluate GAT on the test portion of the data
  • add --enable_tensorboard to start saving metrics (accuracy, loss)

The code is well commented so you can (hopefully) understand how the training itself works.

The script will:

  • Dump checkpoint *.pth models into models/checkpoints/
  • Dump the final *.pth model into models/binaries/
  • Save metrics into runs/ - just run tensorboard --logdir=runs from your Anaconda console to visualize them
  • Periodically write some training metadata to the console
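
Once the script has dumped a checkpoint, you can inspect it with a quick torch.load - a minimal sketch, where the filename is only an assumption based on the naming pattern used elsewhere in this README:

import torch

# hypothetical checkpoint name - adjust to whatever the training script actually dumped
checkpoint_path = 'models/binaries/gat_CORA_000000.pth'

state = torch.load(checkpoint_path, map_location='cpu')
# print the top-level structure to see whether it's a raw state_dict or also carries training metadata
print(type(state))
print(list(state.keys())[:10] if isinstance(state, dict) else state)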

Same goes for training on PPI, just run python training_script_ppi.py. PPI is much more GPU-hungry, so if you don't have a strong GPU with at least 8 GB of VRAM you'll need to add the --force_cpu flag to train GAT on the CPU. Alternatively, you can try reducing the batch size to 1 or making the model slimmer.

You can visualize the metrics during training by calling tensorboard --logdir=runs from your console and pasting the http://localhost:6006/ URL into your browser:

Note: judging by the loss and accuracy metrics, Cora's train split seems to be much harder than the validation and test splits.

Having said that, most of the fun actually lies in the playground.py script.

Tip for understanding the code

I've added 3 GAT implementations - some are conceptually easier to understand, others are more efficient. The most interesting and hardest one to understand is implementation 3. Implementations 1 and 2 differ in subtle details but basically do the same thing.

Advice on how to approach the code:

  • Understand implementation #2 first
  • Check out how it differs from implementation #1
  • Finally, tackle implementation #3

Profiling GAT

If you want to profile the 3 implementations, just set the playground_fn variable to PLAYGROUND.PROFILE_GAT in playground.py.

There are 2 params you may care about:

  • store_cache - set to True if you wish to save the memory/time profiling results after you've run it
  • skip_if_profiling_info_cached - set to True if you want to pull the profiling info from cache

The results get stored in data/ as pickled memory.dict and timing.dict dictionaries.
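
Loading those cached results afterwards is just unpickling - a minimal sketch (the key layout inside the dictionaries is an assumption, so print it to explore):

import pickle

# load the cached profiling results dumped into data/
with open('data/timing.dict', 'rb') as f:
    timing = pickle.load(f)
with open('data/memory.dict', 'rb') as f:
    memory = pickle.load(f)

# the exact key layout is an assumption - inspect it to see how results are grouped per implementation
print(timing.keys())
print(memory.keys())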

Note: implementation #3 is by far the most optimized one - you can see the details in the code.


I've also added profile_sparse_matrix_formats if you want to get some familiarity with different sparse matrix formats like COO, CSR, CSC, LIL, etc.
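
If you'd rather get a feel for those formats directly, here's a minimal standalone sketch using scipy.sparse (independent of the repo's profile_sparse_matrix_formats function):

import numpy as np
from scipy.sparse import coo_matrix

# a tiny adjacency matrix, just to illustrate the formats
dense_adj = np.array([
    [0, 1, 1],
    [1, 0, 0],
    [1, 0, 0],
])

coo = coo_matrix(dense_adj)  # COO: parallel (row, col, data) arrays - easy to construct
csr = coo.tocsr()            # CSR: fast row slicing and matrix-vector products
csc = coo.tocsc()            # CSC: fast column slicing
lil = coo.tolil()            # LIL: convenient for incremental modification

print(coo.row, coo.col, coo.data)
print(csr.indptr, csr.indices)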

Visualization tools

If you want to visualize attention, embeddings (via t-SNE), or entropy histograms, set the playground_fn variable to PLAYGROUND.VISUALIZE_GAT and set visualization_type to:

  • VisualizationType.ATTENTION - if you wish to visualize attention across node neighborhoods
  • VisualizationType.EMBEDDING - if you wish to visualize the embeddings (via t-SNE)
  • VisualizationType.ENTROPY - if you wish to visualize the entropy histograms

And you'll get crazy visualizations like these (using the VisualizationType.ATTENTION option):

On the left you can see the node with the highest degree in the whole Cora dataset.

If you're wondering why these look like a circle, it's because I've used the layout_reingold_tilford_circular layout, which is particularly well suited for tree-like graphs (since we're visualizing a node and its neighbors, this subgraph is effectively an m-ary tree).

But you can also use different drawing algorithms, like Kamada-Kawai (on the right), etc.

Feel free to go through the code and play with plotting attention from different GAT layers, plotting different node neighborhoods or attention heads. You can also easily change the number of layers in your GAT, although shallow GNNs tend to perform the best on small-world, homophilic graph datasets.
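
To give a feel for how such a neighborhood plot comes together, here's a minimal igraph sketch with a made-up star-shaped neighborhood and made-up attention weights; the real plotting code (with the actual learned attention coefficients) lives in playground.py:

import igraph as ig

# made-up star-shaped neighborhood: node 0, its self-edge, and 5 neighbors
edges = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
attention_weights = [0.3, 0.2, 0.2, 0.1, 0.1, 0.1]  # made-up, softmax-style (sums to 1)

g = ig.Graph(edges=edges)
layout = g.layout_reingold_tilford_circular()
# or, for a different look: layout = g.layout('kamada_kawai')

# plotting requires pycairo/cairocffi (see the pycairo issue further down)
ig.plot(
    g,
    layout=layout,
    edge_width=[10 * w for w in attention_weights],  # thicker edge = more attention
    vertex_size=20,
)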


If you want to visualize Cora/PPI just set the playground_fn to PLAYGROUND.VISUALIZE_DATASET and you'll get the results from this README.

Hardware requirements

HW requirements are highly dependent on the graph data you'll use. If you just want to play with Cora, you're good to go with a 2+ GB GPU.

On the Cora citation network:

  • It takes ~10 seconds to train GAT on my RTX 2080 GPU
  • 1.5 GB of VRAM is reserved (PyTorch's caching overhead - far less is allocated for the actual tensors)
  • The model itself is only 365 KB!

Compare this to the hardware needed even for the smallest of transformers!

On the other hand, the PPI dataset is much more GPU-hungry. You'll need a GPU with 8+ GB of VRAM, or you can reduce the batch size to 1 and make the model "slimmer" to try to bring down the VRAM consumption.

Future todos:

  • Figure out why the attention coefficients are equal to 0 (for the PPI dataset, second and third layers)
  • Potentially add an implementation leveraging PyTorch's sparse API

If you have an idea of how to implement GAT using PyTorch's sparse API, please feel free to submit a PR. I personally had difficulties with the API - it's in beta - and it's questionable whether it's even possible to make an implementation as efficient as my implementation #3 using it.
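
To make that todo a bit more concrete, here's a rough sketch of the kind of operation such an implementation would lean on - multiplying a sparse attention matrix with the dense node feature matrix via torch.sparse. This only illustrates the API on a made-up toy graph; it is not a working GAT layer:

import torch

# made-up toy graph: 3 nodes, 3 directed edges with already-normalized attention coefficients
edge_index = torch.tensor([[0, 1, 2],    # source nodes
                           [1, 2, 0]])   # target nodes
attention = torch.tensor([0.7, 1.0, 0.4])
node_features = torch.randn(3, 8)        # N=3 nodes, F=8 features

# sparse (N, N) attention matrix in COO format
attn_matrix = torch.sparse_coo_tensor(edge_index, attention, size=(3, 3))

# neighborhood aggregation as a sparse-dense matmul: out[i] = sum_j attn[i, j] * features[j]
out_features = torch.sparse.mm(attn_matrix, node_features)
print(out_features.shape)  # torch.Size([3, 8])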

Secondly, I'm still not sure why GAT achieves the reported results on PPI while there are obvious numeric problems in the deeper layers, as manifested by all attention coefficients being equal to 0.

Learning material

If you're having difficulties understanding GAT, I did an in-depth overview of the paper in this video:

The GAT paper explained

I also made a walk-through video of this repo (focusing on the potential pain points), and a blog for getting started with Graph ML in general! ❤️

I have some more videos which could further help you understand GNNs:

Acknowledgements

I found these repos useful (while developing this one):

Citation

If you find this code useful, please cite the following:

@misc{Gordić2020PyTorchGAT,
  author = {Gordić, Aleksa},
  title = {pytorch-GAT},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/pytorch-GAT}},
}

Licence

License: MIT


pytorch-gat's Issues

A puzzling question about the implementation of ppi

Thank you @gordicaleksa for the tutorial and implementation of GAT, which helped me understand it smoothly and walk through the code successfully!

I still have a question from running the code on the PPI dataset. In the code you mentioned that the skip connection is important and should be kept, otherwise the micro-F1 is almost 0. I'm wondering whether the reason for that 0 is that some values (such as the loss, features, etc.) become NaN during training?

More details:
The framework I used for training is PaddlePaddle instead of PyTorch. When I ran my code, I found that the loss eventually becomes NaN, making the micro-F1 drop to 0. This strange phenomenon occurs whether I use the skip connection or not. Furthermore, when I set the skip connection to False, the micro-F1 first rises normally to about 0.7 and then suddenly becomes 0; when I set the skip connection to True, the micro-F1 also rises normally, to around 0.9, and then suddenly becomes 0.
When I checked, I found NaNs appearing during training. This problem has puzzled me for a while, so I wonder whether you saw the same issue in your implementation?

Question about the order of concat and activation

In the original paper, the concat operation comes after the activation. However, in your implementation, the order is reversed. Is there a reason for changing the order, or does the order have little influence on the results? Thank you for your attention.

Query regarding visualization of attention

Thank you @gordicaleksa for the fantastic code and detailed documentation! It has helped me a lot in understanding the details of GAT.
While looking at the visualization functions in the code - I understand that entropy is used because the softmax applied over the attention coefficients brings them into the range [0, 1], resembling a probability distribution. To obtain the attention coefficients from the GAT layer in the code, you have used:

def visualize_entropy_histograms(model_name=r'gat_PPI_000000.pth', dataset_name=DatasetType.PPI.name):
    # Fetch the data we'll need to create visualizations
    all_nodes_unnormalized_scores, edge_index, node_labels, gat = gat_forward_pass(model_name, dataset_name)

all_nodes_unnormalized_scores comes from the GAT forward function:

out_nodes_features = self.skip_concat_bias(attentions_per_edge, in_nodes_features, out_nodes_features)
return (out_nodes_features, edge_index)

When reading the GAT paper (Petar Veličković et al.) - the attention coefficients obtained after the softmax are used to compute the final output node features of the GAT layer. In the GAT implementation:

attentions_per_edge = self.neighborhood_aware_softmax(scores_per_edge, edge_index[self.trg_nodes_dim], num_of_nodes)

the above function gives the attention coefficients in the [0, 1] range. The subsequent functions (self.aggregate_neighbors and self.skip_concat_bias) give the final node features from the GAT layer. So is the "all_nodes_unnormalized_scores" variable used in the entropy histogram visualization function still in the range [0, 1]? Or is the entropy histogram used to visualize the output node features rather than the softmax-normalized attention coefficients?

I also came across the entropy visualization in a DGL tutorial on GAT (https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html) and they were using the attention coefficients after softmax normalization for the visualization. Sorry if the question is very naive - I'm trying to apply this visualization to one of my projects involving inductive learning. Let me know if I have misunderstood the information being extracted from the GAT layer. Thanks in advance!

error related to install pycairo

Hi, @gordicaleksa,

Thanks for providing such a great tutorial. I came across a pycairo error when building the virtual env based on the provided environment.yml. The error is as follows:

Pip subprocess error:
  ERROR: Command errored out with exit status 1:
   command: /root/anaconda3/envs/pytorch-gat/bin/python /root/anaconda3/envs/pytorch-gat/lib/python3.8/site-packages/pip/_vendor/pep517/_in_process.py build_wheel /tmp/tmp1xff5oe8
       cwd: /tmp/pip-install-kkpb3rbl/pycairo_24f3dd5c18ec42d9854417db277fda61
  Complete output (15 lines):
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build/lib.linux-x86_64-3.8
  creating build/lib.linux-x86_64-3.8/cairo
  copying cairo/__init__.py -> build/lib.linux-x86_64-3.8/cairo
  copying cairo/__init__.pyi -> build/lib.linux-x86_64-3.8/cairo
  copying cairo/py.typed -> build/lib.linux-x86_64-3.8/cairo
  running build_ext
  Package cairo was not found in the pkg-config search path.
  Perhaps you should add the directory containing `cairo.pc'
  to the PKG_CONFIG_PATH environment variable
  No package 'cairo' found
  Command '['pkg-config', '--print-errors', '--exists', 'cairo >= 1.15.10']' returned non-zero exit status 1.
  ----------------------------------------
  ERROR: Failed building wheel for pycairo
ERROR: Could not build wheels for pycairo which use PEP 517 and cannot be installed directly

Any hints to solve this issue?

(My system is ubuntu 16.04)

Thanks!

Bug in feature aggregation

Hi, @gordicaleksa . Thank you for your implementation of GAT.

I'm new to GNNs, so I'm not sure whether I understood your code correctly, but I think there is a bug in the feature aggregation in your GATLayer. The direction of aggregation appears to be target->source.

In your implementation 1, attention scores are calculated as follows:

# shape = (NH, N, 1) + (NH, 1, N) -> (NH, N, N) with the magic of automatic broadcast <3
# In Implementation 3 we are much smarter and don't have to calculate all NxN scores! (only E!)
# Tip: it's conceptually easier to understand what happens here if you delete the NH dimension
all_scores = self.leakyReLU(scores_source + scores_target.transpose(1, 2))
# connectivity mask will put -inf on all locations where there are no edges, after applying the softmax
# this will result in attention scores being computed only for existing edges
all_attention_coefficients = self.softmax(all_scores + connectivity_mask)

The three dimensions of all_attention_coefficients mean (head, src, tgt), and you apply softmax on dim=-1 i.e. dim=2, making the scores sum up to 1 for each attention head and each source node.

And then in aggregation:

# shape = (NH, N, N) * (NH, N, FOUT) -> (NH, N, FOUT)
out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj)

Let's ignore the head dimension; then this calculates:
out_nodes_features[i, :] = sum_over_j(all_attention_coefficients[i, j] * nodes_features_proj[j, :])
The dimensions of all_attention_coefficients are (head, src, tgt), and those of nodes_features_proj are (node, feat), where the "node" dim corresponds to the "tgt" dim, so out_nodes_features's 2 dims should mean (src, feat).

All of the code above does the following: it calculates an attention score for each node as the source of an edge, and aggregates the features of all its neighboring target nodes.
However, based on my understanding, the feature aggregation in GAT should go in the opposite direction: collecting source nodes into each target.

Implementation 2 also has the same problem. I'm still working on understanding implementation 3, so I don't know if the bug persists there.

cairocffi is also necessary in environment.yaml

I built the conda env with conda env create -f environment.yaml. But before running pip install cairocffi, ig.plot() shows the error TypeError: plotting not available.
So maybe cairocffi should be included in environment.yaml:

    - cairocffi==1.2.0

Training a GAT for Text classification

I have a dataset with 62 classes. What do I have to do to train a GNN or GAT on it? I'm confused. Also, if I do train a GNN or GAT, how do I convert the dataset into a format that it accepts as input?

Training a GAT for Graph classification

Hello, author. I have 969 samples, each of which corresponds to the same graph structure and has 6888 nodes. I want to do graph classification on the basis of your source code, but after trying for a long time I haven't found a suitable way to prepare the data. I don't know how to convert the raw data from .csv format into a format GAT supports, such as feat/label/graph_id .npy files and graph.json. Your help will be greatly appreciated!
