
dir-gnn's Introduction

Discovering Invariant Rationales for Graph Neural Networks 🔥

Overview

DIR (ICLR 2022) aims to train intrinsically interpretable Graph Neural Networks that are robust and generalizable to out-of-distribution datasets. The core of this work lies in the construction of interventional distributions, from which causal features are identified. See the quick lead-in below.

  • Q: What are interventional distributions?

They are essentially the distributions that arise when we intervene on one variable, or a set of variables, in the data-generation process. For example, we could intervene on the base graph (highlighted in green or blue), which gives us multiple distributions.

  • Q: How to construct the interventional distributions?
    We design the following model structure to do the intervention in the representation space, where the distribution intervener is in charge of sampling one subgraph from the non-causal pool and fixing it at one end of the rationale generator.

  • Q: How can these interventional distributions help us approach the causal features for rationalization?

Here is the simple philosophy: no matter what values we assign to the non-causal part, the class label is invariant as long as we observe the causal part. Intuitively, interventional distributions offer us "multiple eyes" to discover the features that make the label invariant under interventions. We propose the DIR objective to achieve this goal.

    See our paper for the formal description and the principle behind it.
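The two answers above can be sketched together in a few lines. This is a schematic illustration, not the repository's implementation: the function names, the `combine` operator, and the penalty weight `lam` are all hypothetical, and the exact objective is the one given in the paper.

```python
import random

def intervene(h_causal, noncausal_pool, combine):
    """Distribution intervener (sketch): pair the causal representation
    with a spurious representation sampled from the non-causal pool."""
    h_spurious = random.choice(noncausal_pool)
    return combine(h_causal, h_spurious)

def dir_risk(losses, lam):
    """DIR-style risk (sketch): the mean risk over the interventional
    distributions plus a variance penalty, pushing the predictor to
    rely only on features whose loss is invariant to interventions."""
    n = len(losses)
    mean = sum(losses) / n
    var = sum((l - mean) ** 2 for l in losses) / n
    return mean + lam * var
```

The variance term is what makes the "multiple eyes" intuition concrete: a rationale whose loss changes across interventions is penalized, while a truly causal rationale is not.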

Installation

Note that we require 1.7.0 <= torch_geometric <= 2.0.2. Simply run the commands below to install the Python environment (you may want to change cudatoolkit to match your CUDA version), or see requirements.txt for the package list.

sh setup_env.sh
conda activate dir

Data download

  • Spurious-Motif: this dataset can be generated via spmotif_gen/spmotif.ipynb.
  • Graph-SST2: this dataset can be downloaded here.
  • MNIST-75sp: this dataset can be downloaded here. Download mnist_75sp_train.pkl, mnist_75sp_test.pkl, and mnist_75sp_color_noise.pt to the directory data/MNISTSP/raw/.

Run DIR

The hyper-parameters used to train the intrinsically interpretable models are set as defaults in the argparse.ArgumentParser in the training files. Feel free to change them as needed. We use a separate training file for each dataset.

Simply run python -m train.{dataset}_dir to reproduce the results in the paper.

Common Questions:

How does the Rationale Generator update its parameters? See #7.

Reference

@inproceedings{
    wu2022dir,
    title={Discovering Invariant Rationales for Graph Neural Networks},
    author={Ying-Xin Wu and Xiang Wang and An Zhang and Xiangnan He and Tat-seng Chua},
    booktitle={ICLR},
    year={2022},
}

dir-gnn's People

Contributors

wuyxin


dir-gnn's Issues

How do the practical objectives match the theoretical ones?

Thank you very much for sharing such a good paper!
But I am a little confused by the practical objectives (Equations 11, 10, and 9), maybe because I have missed something.
I have read Appendix E, but I still have no idea how Equation 11 helps to achieve Equation 4 or 3. Could you please give me some help?

creating interventional distributions

Hello, thank you for your excellent work. I have some confusion regarding the creation of interventional distributions. The paper mentions creating a memory bank from which to sample for interventions, but I haven't found the relevant code in your repository. Could you help me with this? Thank you.

Questions about Equation 6.

Thank you very much for providing such a good paper.
But I am a little confused by Equation 6. As you first get a soft maks socre (M) for each edge, and you select the top-r edges in Equation 6. Are The selected edges soft values between [0,1] or just binarized {0,1} values? If they are binarized values, it seems that the gradients will be detached. But if they are continuous values, it will be strange to select top-r of them.

conda env yaml

I can't install the packages from the requirements.txt you provided because of package conflicts. Could you give an updated version, or provide a conda env yaml (conda env export > dir.yaml) if you use conda?

Segmentation fault (core dumped)

When I import some packages such as ogb and torch_geometric, the code fails with a segmentation fault (core dumped). Do you know why?

Question about the Spurious-Motif dataset

Hi, I'm really interested in this work and I have a question about lines 45-48 of DIR-GNN/datasets/spmotif_dataset.py:

x = torch.zeros(node_idx.size(0), 4)
index = [i for i in range(node_idx.size(0))]
x[index, z] = 1
x = torch.rand((node_idx.size(0), 4))

I'm wondering whether the node features are randomized or generated from the role id.

baseline in Table1

Hello, I would like to reproduce the baselines in Table 1 of your paper. Would it be convenient for you to release them, or to provide a link?

An issue with calculating precision@K

The way you calculate precision@K may be wrong, since you should divide by K in the last line instead of num_gd. I also want to ask what the meanings of C and E are.

num_gd = int(ground_truth_mask[C: C + E].sum())
pred = pred_weight[C:C + E]
_, indices_for_sort = pred.sort(descending=True, dim=-1)
idx = indices_for_sort[:num_gd].detach().cpu().numpy()
precision.append(ground_truth_mask[C: C + E][idx].sum().float()/num_gd)

Thank you very much!
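For reference, a generic precision@K that divides by K, as the issue suggests, might look like the sketch below (not the repository's code). Note that when K equals the number of ground-truth edges, as in the snippet above, the two divisors coincide.

```python
def precision_at_k(ground_truth, pred_scores, k):
    """Precision@K: of the k highest-scoring edges, the fraction that
    are ground-truth edges. The divisor is k; dividing by the number
    of ground-truth edges would instead compute recall."""
    order = sorted(range(len(pred_scores)), key=lambda i: -pred_scores[i])
    top_k = order[:k]
    return sum(ground_truth[i] for i in top_k) / k
```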

How to update the trainable parameters for edge models & How did you set the edge scores as masks in the message-passing?

Hi, thanks for sharing the code.

I noticed that you used self.mlp to work on pairs of node representations to obtain the edge scores. Then this edge score is used to select the edges of the causal subgraph.

However, I have two questions. (1) You mentioned that M_ij is calculated by sigmoid(Z_i^T Z_j), rather than by using a parametric network to get the mask matrix. (2) As far as I can tell, the parameters of this edge model self.mlp cannot be updated by backpropagation during training; in other words, its parameters are fixed. Could you please give me some more explanation so that I can better understand how this edge model works?
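As background on question (2): if the edge score is used multiplicatively inside message passing, gradients do reach whatever produced the score. A toy sketch with scalar node representations (hypothetical names, not the repository's `self.mlp`-based implementation):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def masked_message_passing(node_reps, edges):
    """One message-passing step where each directed edge (i, j) is
    weighted by M_ij = sigmoid(z_i * z_j). Because M_ij scales the
    message, it stays inside the computation graph, so gradients can
    reach the model that produced the node representations."""
    out = {i: 0.0 for i in node_reps}
    for i, j in edges:
        m_ij = sigmoid(node_reps[i] * node_reps[j])  # scalar reps for brevity
        out[j] += m_ij * node_reps[i]
    return out
```

With vector representations and an MLP scorer the same principle holds: as long as the score multiplies the message rather than gating it with a hard threshold, the scorer's parameters receive gradients.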

requirements.txt

There is an absolute path in the requirements file. Can you regenerate the requirements file without the absolute path? Alternatively, you could use this command:
pip list --format=freeze > requirements.txt
Thanks!

Dataset MolHIV

Hi, thanks for sharing the code. However, I wonder whether you have missed one of the four datasets used in your paper: the MolHIV dataset. Could you please provide some information on how you use it, along with the relevant code?

About the size of base graph

Thank you for your work. I want to ask why you change the size of the base graph, and why the sizes of the base graphs differ between training and testing.

parameters in paper

Hi, I would like to ask what parameters of the Spurious-Motif dataset you used in your paper, i.e. the parameters of the graph_stats(base_num) function in spmotif_gen/spmotif.ipynb.
