wuyxin / refine

Official code of "Towards Multi-Grained Explainability for Graph Neural Networks" (NeurIPS 2021) + PyTorch implementations of recent attribution methods for GNNs

Home Page: https://proceedings.neurips.cc/paper_files/paper/2021/file/99bcfcd754a98ce89cb86f73acc04645-Paper.pdf

License: MIT License

graph-neural-network multi-grained-explainability neurips2021

refine's Introduction

ReFine: Multi-Grained Explainability for GNNs

This is the official code for Towards Multi-Grained Explainability for Graph Neural Networks (NeurIPS 2021). In addition, we provide highly modularized explainers for graph classification tasks, some of which are adapted from the image domain. Below is a summary:

Explainer Paper
ReFine Towards Multi-Grained Explainability for Graph Neural Networks
SA Explainability Techniques for Graph Convolutional Networks
Grad-CAM Explainability Methods for Graph Convolutional Neural Networks
DeepLIFT Learning Important Features Through Propagating Activation Differences
Integrated Gradients Axiomatic Attribution for Deep Networks
GNNExplainer GNNExplainer: Generating Explanations for Graph Neural Networks
PGExplainer Parameterized Explainer for Graph Neural Network
PGM-Explainer PGM-Explainer: Probabilistic Graphical Model Explanations for Graph Neural Networks
Screener Causal Screening to Interpret Graph Neural Networks
CXPlain CXPlain: Causal Explanations for Model Interpretation under Uncertainty

Installation

Requirements

  • CPU or NVIDIA GPU, Linux, Python 3.7
  • PyTorch >= 1.5.0, other packages
  1. PyTorch Geometric. Official Download. (A quick version sanity check appears after this list.)
# We use TORCH version 1.6.0
CUDA=cu102
TORCH=1.6.0 
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric==1.7.0
  2. Visual Genome (optional). Google Drive Download. This is used for preprocessing the VG-5 dataset and visualizing the generated explanations. Manually download it to the same directory as data. (The package can also be accessed via its API, but we found it slow to use.) You can still run the other datasets without downloading it.

  3. Other packages

pip install tqdm matplotlib pgmpy==0.1.11
# (logging, pathlib, argparse, and json ship with Python's standard library and need no pip install)
# For visualization (optional) 
conda install -c conda-forge rdkit
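
After installation, you can optionally confirm that your local torch build matches the wheel URL used for the PyTorch Geometric extensions above. A minimal sanity check (our suggestion, not part of the original instructions):

import torch
print(torch.__version__)   # expect 1.6.0, matching TORCH above
print(torch.version.cuda)  # expect 10.2, matching CUDA=cu102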

Datasets

  1. The processed data for BA-3motif is available in the data/ folder.
  2. The MNIST and Mutagenicity datasets will be downloaded automatically when the models are trained.
  3. We select and label 4443 graphs from https://visualgenome.org/ to construct the VG-5 dataset. The graphs are labeled with five classes: stadium, street, farm, surfing, forest. Each graph takes the object regions as its nodes, while the edges indicate the relationships between object nodes. Download the dataset from Google Drive and arrange the directory as
data
├── BA3
└── VG
    └── raw

Please also cite Visual Genome (bibtex) if you use this dataset.
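
For intuition, one VG-5 graph as described above could look like the following hypothetical sketch (this is not the repo's loader; the feature size, edges, and label are made up for illustration):

import torch
from torch_geometric.data import Data

num_objects = 4                                    # object regions in one image
x = torch.randn(num_objects, 128)                  # placeholder node features per region
edge_index = torch.tensor([[0, 0, 1], [1, 2, 3]])  # relationships between object nodes
y = torch.tensor([2])                              # one of the five classes, e.g. "farm"
g = Data(x=x, edge_index=edge_index, y=y)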

Train GNNs

We provide the trained GNNs in param/gnns for reproducing the results in our paper. To retrain the GNNs, run

cd gnns/
bash run.sh

The trained GNNs will be saved in param/gnns.
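
A saved checkpoint can be loaded back with torch.load; a minimal sketch using the path pattern that also appears in the explaining steps below:

import torch

gnn = torch.load('param/gnns/ba3_net.pt')  # trained GNN for the ba3 dataset
gnn.eval()                                 # switch to evaluation mode for inference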

Explaining the Predictions

  1. For global training of PGExplainer and ReFine, run
cd train/
bash run.sh
  2. Load datasets
from utils.dataset import get_datasets
from torch_geometric.data import DataLoader

name = 'ba3'
train_dataset, val_dataset, test_dataset = get_datasets(name=name)
test_loader = DataLoader(test_dataset, batch_size=1)
  3. Instantiate the explainer
from explainers import *
import torch

device = torch.device("cuda")
gnn_path = f'param/gnns/{name}_net.pt'

refine = torch.load(f'param/refine/{name}.pt') # load pretrained
refine.remap_device(device)
  4. Explain
for g in test_loader:
    refine.explain_graph(g, fine_tune=True,
                         ratio=0.4, lr=1e-4, epoch=20)

For baseline explainers, e.g.,

gnn_explainer = GNNExplainer(device, gnn_path)
gnn_explainer.explain_graph(g, epochs=100, lr=1e-2)

screener = Screener(device, gnn_path)
screener.explain_graph(g)
  5. Evaluation & Visualization

Evaluation and visualization are unified across all explainers. After a single graph has been explained, the pair (graph, edge_imp: np.ndarray) is saved as explainer.last_result by default, and it is this pair that is evaluated or visualized.

ratios = [0.1 * i for i in range(1, 11)]
acc_auc = refine.evaluate_acc(ratios).mean()
recall = refine.evaluate_recall(topk=5)
refine.visualize(vis_ratio=0.3)  # visualize the explanation
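
Since the last explanation is cached on the explainer, a custom analysis can read it directly. A minimal sketch (the attribute and pair layout follow the description above; the top-5 cutoff is arbitrary):

import numpy as np

graph, edge_imp = refine.last_result            # pair saved by explain_graph
topk = np.argsort(-edge_imp)[:5]                # indices of the 5 most important edges
print(graph.edge_index.cpu().numpy()[:, topk])  # endpoints of those edges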

To evaluate ReFine-FT and ReFine on the test datasets, run

python evaluate.py --dataset ba3

The results will be saved to results/ba3_results.json, where ReFine-FT.ACC-AUC (ReFine-FT.Recall@5) and ReFine.ACC-AUC (ReFine.Recall@5) record the performance of ReFine-FT and ReFine, respectively.
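
The file can then be inspected programmatically; a minimal sketch, assuming the metric names above are the literal JSON keys:

import json

with open('results/ba3_results.json') as f:
    results = json.load(f)
print(results['ReFine.ACC-AUC'], results['ReFine.Recall@5'])  # keys assumed from above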

Citation

Please cite our paper if you find the repository useful.

@inproceedings{refine,
  title={Towards Multi-Grained Explainability for Graph Neural Networks},
  author={Wang, Xiang and Wu, Ying-Xin and Zhang, An and He, Xiangnan and Chua, Tat-Seng},
  booktitle={Proceedings of the 35th Conference on Neural Information Processing Systems},
  year={2021} 
}


refine's Issues

BA3-Motif cannot be created correctly.

When I try to train another GNN on the BA3-Motif dataset, the process function in the BA3Motif class throws an exception.

Concretely, in ba3motif_dataset.py, line 67:

torch.save(self.collate(data_list[800:]), self.processed_paths[0])

KeyError: 15

The collate function always raises the KeyError above.

Missing param/filtered directory

Hi,

It looks like the param/filtered folder is missing.
Could you upload those index .pt files?

Thank you,
Zhaoning

Strange behavior with randomness

On the first run, say after I clone the repo and train with python3 refine_train.py --dataset ba3 --hid 50 --epoch 1 --ratio 0.4 --lr 1e-4, I always get the same ACC-AUC of 0.518. On the second and all subsequent runs, the same command gives me an ACC-AUC of 0.490.

This happens with any number of epochs but is easiest to verify with just one. It seems like something is not quite working with the random seed on the first run (although the first run is still seeded, as it produces the same result of 0.518 every time), and once the model has been trained once, the seed starts working. I even cloned the repo several more times, and the same pattern always occurred.

I tried to fix this myself but couldn't. This isn't really critical to fix, but I think it's worth mentioning, as it caused quite a bit of confusion when I was testing the code.

The torch version is 1.8.0 because I couldn't get 1.6.0 to work with torch-scatter. Otherwise the setup is the same as in the README.

ACC-AUC on BA3-Motif

Hi, I want to ask something regarding the performance on BA3-Motif.

You reported that ReFine achieves an ACC-AUC of 0.630. However, in your log file, the highest ACC-AUC of ReFine is only 0.612. Is something wrong with the log you uploaded?

How to deal with HeteroData objects in ReFine?

I encounter problems when I use the HAN model for graph classification tasks. I have read the Mutag dataset code, which handles datasets as torch_geometric.data.Data, but the HAN model uses HeteroData.
I want to know whether ReFine can deal with HeteroData objects.

Performance of backbone GNN

Hi, thanks for sharing the code, which is well written. I re-implemented the GNN on all four datasets and found that the GNNs on MNIST/VG/BA3 perform close to what is reported in the paper. However, the test accuracy is only 80% on MUTAG, whereas you report that it reaches 100%.

I know this does not affect the effectiveness of your method. But since I want to follow your work, it would be better to figure out whether the backbone GNN is strong enough on MUTAG. What do you think of this gap?

Question about the EdgeMask

Hi Ying-Xin,

Thanks for your excellent work and concise code!

I am puzzled by the mask function used in both PGExplainer and ReFine. Specifically, the mask function consists of the ARMAConv operations and MLPs in common.py, but I haven't found any description of ARMAConv in the paper. So I am wondering why the ARMAConv operations are necessary here.

conv = ARMAConv(in_channels=hid, out_channels=hid)

Looking forward to your reply, many thanks! :)

ReFine training on other GNNs

Hi,

I am trying to use ReFine on other GNNs, such as GCN, but I cannot achieve good results.
Could you tell me which parts of the settings I should change to achieve the best performance?

Failed to reproduce the results on the Mutag dataset

Hello,

I retrained the downstream model following the documentation (e.g. python mutag_gnn.py) and also the ReFine model using the recommended command: python refine_train.py --dataset mutag --hid 100 --epoch 100 --ratio 0.4 --lr 1e-3 --batch_size 64

However, at the end of training I get an AUC of 0.82, which is quite far from the reported 0.955. Am I missing something?

2022-07-01 18:01:23,919 - refine_train.py[line:163] - INFO: ACC:[[0.86 0.81 0.78 0.78 0.77 0.79 0.77 0.8 0.86 1. ]] ACC-AUC: 0.822 Mean P: 0.099
2022-07-01 18:01:44,651 - refine_train.py[line:158] - INFO: Epoch: 98, LR: 0.00001, Ratio: 0.40, Train Loss: -163.027, Val Loss: -1.704

ReFine training

Hello, thanks for making the explainer available! I was trying to run the training step and got output like the following (running python refine_train.py --dataset ba3 --hid 50 --epoch 25 --ratio 0.4 --lr 1e-4):

2021-12-22 11:22:08,770 - refine_train.py[line:95] - INFO: number of graphs(train): 2185
2021-12-22 11:22:08,775 - refine_train.py[line:95] - INFO: number of graphs(val):  398
2021-12-22 11:22:08,777 - refine_train.py[line:95] - INFO: number of graphs(test):  397
2021-12-22 11:23:00,702 - refine_train.py[line:158] - INFO: Epoch: 1, LR: 0.00010, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:23:00,709 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.32 0.35 0.37 0.52 0.77 0.96 1.  ]] ACC-AUC: 0.521 Mean P: nan
2021-12-22 11:23:49,479 - refine_train.py[line:158] - INFO: Epoch: 2, LR: 0.00010, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:24:37,940 - refine_train.py[line:158] - INFO: Epoch: 3, LR: 0.00010, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:24:37,943 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.34 0.37 0.56 0.76 0.97 1.  ]] ACC-AUC: 0.523 Mean P: nan
2021-12-22 11:25:26,978 - refine_train.py[line:158] - INFO: Epoch: 4, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:26:15,519 - refine_train.py[line:158] - INFO: Epoch: 5, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:26:15,522 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.36 0.37 0.52 0.73 0.97 1.  ]] ACC-AUC: 0.519 Mean P: nan
2021-12-22 11:27:03,904 - refine_train.py[line:158] - INFO: Epoch: 6, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:27:52,747 - refine_train.py[line:158] - INFO: Epoch: 7, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:27:52,750 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.39 0.38 0.52 0.77 0.97 0.99]] ACC-AUC: 0.526 Mean P: nan
2021-12-22 11:28:41,475 - refine_train.py[line:158] - INFO: Epoch: 8, LR: 0.00001, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:29:29,920 - refine_train.py[line:158] - INFO: Epoch: 9, LR: 0.00001, Ratio: 0.40, Train Loss: nan, Val Loss: nan
2021-12-22 11:29:29,923 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.35 0.37 0.51 0.74 0.97 1.  ]] ACC-AUC: 0.517 Mean P: nan

Is it normal that the train loss and validation loss are nan? Also, the ACC-AUC does not seem to improve meaningfully, so I'm wondering where the issue is...

How to construct VG-5 Dataset.

I have downloaded visual_genome from https://github.com/ranjaykrishna/visual_genome_python_driver

and image_data.json & synsets.json from https://visualgenome.org/api/v0/api_home.html.

But where can I get the per-image image_id.json files, such as 35.json? I hope to get your reply.

Explainer code

Hello,

When can we expect the explainer code to be included? We're quite interested in reproducing the results.

Question about the ACC-AUC metric in the paper

Hi,

Does the ACC-AUC metric in your paper mean the ROC-AUC of the predictions on the generated subgraphs?
How do you calculate the ACC-AUC?

Thank you,
Zhaoning

Performance of other baseline methods

Hi, I notice that you provide scripts for many other explanation methods. Did you test the performance of those approaches? If you don't mind, could you share their full results?
Thanks.
