rexying / gnn-model-explainer
gnn explainer
License: Apache License 2.0
Hi,
It seems the default training objective of the GCN model for the benchmark datasets (e.g., REDDIT, Mutag) is graph classification. This is quite different from the syn1 to syn5 datasets, which the code sets to node classification by default.
So I tried to generate explanations for REDDIT and Mutag with:
python -W ignore explainer_main.py --bmname=REDDIT-BINARY --gpu --graph-mode
python -W ignore explainer_main.py --bmname=Mutagenicity --gpu --graph-mode
but both attempts give me errors:
Traceback (most recent call last):
File "explainer_main.py", line 316, in <module>
main()
File "explainer_main.py", line 280, in main
explainer.explain_graphs(graph_indices=[1, 2, 3, 4])
File "/mnt/server5_hard1/seungjun/gnn-model-explainer/explainer/explain.py", line 363, in explain_graphs
masked_adj = self.explain(node_idx=0, graph_idx=graph_idx, graph_mode=True)
File "/mnt/server5_hard1/seungjun/gnn-model-explainer/explainer/explain.py", line 172, in explain
node_idx_new, epoch, label=single_subgraph_label
File "/mnt/server5_hard1/seungjun/gnn-model-explainer/explainer/explain.py", line 966, in log_masked_adj
args=self.args,
File "/mnt/server5_hard1/seungjun/gnn-model-explainer/utils/io_utils.py", line 316, in log_graph
raise Exception("empty edge")
Exception: empty edge
Traceback (most recent call last):
File "explainer_main.py", line 316, in <module>
main()
File "explainer_main.py", line 280, in main
explainer.explain_graphs(graph_indices=[1, 2, 3, 4])
File "/mnt/server5_hard1/seungjun/gnn-model-explainer/explainer/explain.py", line 396, in explain_graphs
args=self.args
File "/mnt/server5_hard1/seungjun/gnn-model-explainer/utils/io_utils.py", line 347, in log_graph
alpha=0.8,
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/networkx/drawing/nx_pylab.py", line 123, in draw
draw_networkx(G, pos=pos, ax=ax, **kwds)
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/networkx/drawing/nx_pylab.py", line 336, in draw_networkx
draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/networkx/drawing/nx_pylab.py", line 684, in draw_networkx_edges
alpha=alpha,
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/matplotlib/collections.py", line 1378, in __init__
colors = mcolors.to_rgba_array(colors)
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/matplotlib/colors.py", line 341, in to_rgba_array
return np.array([to_rgba(cc, alpha) for cc in c])
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/matplotlib/colors.py", line 341, in <listcomp>
return np.array([to_rgba(cc, alpha) for cc in c])
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/matplotlib/colors.py", line 189, in to_rgba
rgba = _to_rgba_no_colorcycle(c, alpha)
File "/home/cilab5/anaconda3/envs/sj_gnn_explainer/lib/python3.6/site-packages/matplotlib/colors.py", line 263, in _to_rgba_no_colorcycle
raise ValueError(f"Invalid RGBA argument: {orig_c!r}")
ValueError: Invalid RGBA argument: tensor(1., dtype=torch.float64)
I would be really grateful if I could find out how to generate explanations for the benchmark datasets :)
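Both tracebacks end in the plotting step, io_utils.log_graph: the first is raised when no edge is left to draw ("empty edge"), and the second shows an edge weight reaching matplotlib as a torch tensor (tensor(1., dtype=torch.float64)) rather than a plain float. A minimal workaround sketch, assuming the masked adjacency is available as a nested list or array; extract_plot_edges and its top_k fallback are hypothetical, not part of the repo:

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int, float]


def extract_plot_edges(
    masked_adj: Sequence[Sequence[float]],
    threshold: float = 0.5,
    top_k: int = 5,
) -> List[Edge]:
    """Collect (u, v, weight) triples for plotting a masked adjacency.

    Hypothetical helper: weights are cast to plain Python floats so that
    matplotlib never receives a tensor, and if the threshold removes every
    edge we fall back to the top_k strongest edges instead of failing.
    """
    n = len(masked_adj)
    candidates: List[Edge] = []
    for u in range(n):
        for v in range(u + 1, n):  # upper triangle: treat the graph as undirected
            w = float(masked_adj[u][v])  # float() also works on one-element tensors
            if w > 0.0:
                candidates.append((u, v, w))

    kept = [e for e in candidates if e[2] >= threshold]
    if not kept:
        # Nothing passed the threshold: keep the strongest edges instead.
        kept = sorted(candidates, key=lambda e: e[2], reverse=True)[:top_k]
    return kept
```

The resulting list can be passed to networkx with G.add_weighted_edges_from(edges), and the float weights used directly as edge colours.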
Describe the bug
I downloaded the Mutagenicity data and put it into the data/ directory, then replaced 'dataset = syn1' with 'bmname=Mutagenicity' in train.py and configs.py and ran train.py successfully. But when I run explainer_main.py, an error occurs. I followed the steps in the README directly, but it failed.
Hi @RexYing, thanks for the contribution!
It seems that feat_mask_ent_loss is not added to the total loss in explainer/explain.py. Could you explain this, or is there something I am missing?
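For reference, an element-wise entropy regularizer on a sigmoid mask is usually the mean binary entropy of the mask entries, which pushes each entry toward 0 or 1. A minimal stdlib sketch of such a term and of adding it to a total loss, assuming hypothetical coefficient names ("size", "ent", "feat_ent"); it is not copied from explain.py:

```python
import math
from typing import Dict, Sequence


def mask_entropy(mask: Sequence[float], eps: float = 1e-15) -> float:
    """Mean binary entropy of mask values in (0, 1).

    Entries near 0 or 1 contribute little and entries near 0.5 contribute the
    most, so minimizing this term makes the mask more discrete.
    """
    total = 0.0
    for m in mask:
        total += -m * math.log(m + eps) - (1.0 - m) * math.log(1.0 - m + eps)
    return total / len(mask)


def total_loss(pred_loss: float, edge_mask: Sequence[float],
               feat_mask: Sequence[float], coeffs: Dict[str, float]) -> float:
    """Combine a prediction loss with size and entropy regularizers.

    The coefficient keys are illustrative only.
    """
    size_loss = coeffs["size"] * sum(edge_mask)
    edge_ent_loss = coeffs["ent"] * mask_entropy(edge_mask)
    feat_ent_loss = coeffs["feat_ent"] * mask_entropy(feat_mask)  # the term asked about
    return pred_loss + size_loss + edge_ent_loss + feat_ent_loss
```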
Describe the bug
Hi @RexYing, I could not find where the accuracy reported in Table 1 of the GNNExplainer paper is computed. I looked through all the code and the Jupyter notebooks but could not locate it. Could you help me find it? Thank you very much in advance.
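For the synthetic datasets, one plausible way to score an explanation against ground truth is to treat the edges inside the planted motif as positives and compute the ROC AUC of the edge-mask values. Whether this is exactly how Table 1 was produced is an assumption; the sketch below only illustrates the metric, using the pairwise (Mann-Whitney) definition of AUC:

```python
from typing import Sequence


def roc_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """ROC AUC as the probability that a random positive outscores a random negative.

    scores: mask value of each edge; labels: 1 if the edge belongs to the
    ground-truth motif, else 0. Ties count as half a win.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative edge")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

For larger graphs, scikit-learn's sklearn.metrics.roc_auc_score computes the same quantity more efficiently.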
This was asked previously but has not been addressed. Is it currently possible to run GNNExplainer on our own data and models? If so, could you give an example of how to adapt the files to allow this?
Thanks for sharing your amazing work.
README.md contains a link to https://observablehq.com/d/00c5dc74f359e7a1, which seems to return a 404.
I can't find the D3-based visualisation in the repository. Is it available anywhere?
Describe the bug
Hi,
I am encountering a bug while training the GNN model on the Tox21 data. After training finishes, I get the error below, and the checkpoint is not saved either. What should I do?
Train accuracy: 0.9083333333333333
Validation accuracy: 0.8970588235294118
Test accuracy: 0.9105392156862745
Best val result: {'epoch': 700, 'loss': tensor(0.1380, device='cuda:0', grad_fn=<DivBackward0>), 'acc': 0.9044117647058824}
Test result: {'prec': 0.8022022213711222, 'recall': 0.72627360451945, 'acc': 0.9105392156862745, 'epoch': 999}
Traceback (most recent call last):
File "train.py", line 1179, in <module>
main()
File "train.py", line 1156, in main
benchmark_task(prog_args, writer=writer)
File "train.py", line 932, in benchmark_task
writer=writer,
File "train.py", line 247, in train
plt.savefig(io_utils.gen_train_plt_name(args), dpi=600)
File "C:\Users\kacpe\Desktop\Github\gnn-model-explainer\utils\io_utils.py", line 667, in gen_train_plt_name
return "results/" + io_utils.gen_prefix(args) + ".png"
NameError: name 'io_utils' is not defined
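The last frame shows that gen_train_plt_name lives in utils/io_utils.py yet refers to its own module as io_utils, a name that is not defined inside that module. A likely fix, assuming gen_prefix is defined in the same file (as the call io_utils.gen_prefix implies), is to call it directly. In the sketch below gen_prefix is a simplified hypothetical stand-in, not the repo's implementation:

```python
from types import SimpleNamespace


def gen_prefix(args) -> str:
    # Hypothetical stand-in for io_utils.gen_prefix; the real function builds
    # a longer name from many more arguments.
    return f"{args.bmname}_{args.method}"


def gen_train_plt_name(args) -> str:
    # Inside io_utils.py, call the module-level function directly instead of
    # going through the undefined name io_utils.
    return "results/" + gen_prefix(args) + ".png"


if __name__ == "__main__":
    args = SimpleNamespace(bmname="Tox21_AHR", method="base")
    print(gen_train_plt_name(args))  # results/Tox21_AHR_base.png
```

Note that plt.savefig will also fail if the results/ directory does not exist yet.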
Hi
I am using your great code! I have slightly modified your models.py to fit my regression task, but when I run train.py, the following error occurs:
Traceback (most recent call last):
File "train.py", line 1273, in <module>
main()
File "train.py", line 1250, in main
benchmark_task(prog_args, writer=writer)
File "train.py", line 1024, in benchmark_task
writer=writer,
File "train.py", line 190, in train
ypred, att_adj = model(h0, adj, batch_num_nodes, assign_x=assign_input) #use the 'forward' that is in model class
File "/home/omnisky/venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/omnisky/Documents/individual_variance/gnn-model-explainer/models.py", line 570, in forward
self.assign_pred_modulesi
File "/home/omnisky/venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/omnisky/venv/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "/home/omnisky/venv/lib/python3.7/site-packages/torch/nn/functional.py", line 1372, in linear
output = input.matmul(weight.t())
RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat2' in call to _th_mm
Could you give me some advice on this? Thank you for your kind assistance!
Zeng
Hi, I have read and tried the code, and it now works on my graph-level GNN model in the multi-instance case.
The input graph in my project is a spatiotemporal graph that connects the subgraphs of all frames into one whole graph.
Then I realized the difference between my project and the applications of GNNExplainer: the labels used in GNNExplainer are node labels, while I am more interested in explaining the relation between the GNN model and the whole graph's label.
If I change the label in GNNExplainer to the graph label, can GNNExplainer then generate a meaningful explanation? Considering the mutual-information objective of GNNExplainer, I think it can work, but I am not confident enough.
Could someone please give me some advice? This would be very helpful.
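For reference, the GNNExplainer paper turns its mutual-information objective into the cross-entropy below, where A_c and X_c are the adjacency and features of the computation graph, M is the learned edge mask, σ is the sigmoid and Φ is the trained GNN (notation as we read the paper; please double-check it against the original). Nothing in the expression is tied to node labels: for graph classification one would presumably take A_c and X_c from the whole input graph and y as the graph label, which appears to be what the repository's --graph-mode option targets.

```latex
\min_{M} \; -\sum_{c=1}^{C} \mathbb{1}[y = c] \,
  \log P_{\Phi}\!\left(Y = y \mid G = A_c \odot \sigma(M),\; X = X_c\right)
```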
Dear Authors,
Thanks for sharing the code. I am trying to reproduce the results for comparison, but some settings are not clear.
For example, what is the GNN architecture? Is it the same for all synthetic datasets?
Which nodes are explained in the quantitative evaluation: only range(400,700,5), or all nodes inside motifs?
Describe the bug
Could you please add pretrained models? I would love to play around with them. Also, could you make training on the Tox21 dataset possible, please?
Also, the visualisation page is down.
This is really great work - I don't mean to be complaining at all. Keep it up!
To Reproduce
python train.py --dataset Tox21_AHR does not work, as there is no config in train.py for the Tox21 dataset.
Based on the code in the gengraph and synthetic_structsim modules, I gather that when training on the synthetic datasets, the input graphs are generated on the fly. Given the randomness in methods such as synthetic_structsim.build_graph(), every invocation of the method will likely generate a completely different instance of the input graph.
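One way to make these graphs reproducible, assuming the generators draw from Python's random module and NumPy's global RNG, is to seed both before building the graph (e.g. random.seed(0) and np.random.seed(0)). The stdlib-only sketch below uses a hypothetical toy generator, not synthetic_structsim.build_graph(), to show the effect of seeding:

```python
import random
from typing import List, Tuple


def build_toy_graph(num_nodes: int, num_extra_edges: int,
                    rng: random.Random) -> List[Tuple[int, int]]:
    """Hypothetical generator: a path graph plus randomly attached extra edges."""
    edges = [(i, i + 1) for i in range(num_nodes - 1)]
    for _ in range(num_extra_edges):
        u, v = rng.sample(range(num_nodes), 2)  # two distinct endpoints
        edges.append((min(u, v), max(u, v)))
    return edges


if __name__ == "__main__":
    g1 = build_toy_graph(10, 3, random.Random(42))
    g2 = build_toy_graph(10, 3, random.Random(42))
    print(g1 == g2)  # True: same seed, same graph
```

Passing an explicit random.Random instance, rather than relying on the global RNG, keeps graph generation reproducible even if other code consumes random numbers in between.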
Hello,
I am trying to use explain_pyg.py to train the PyTorch Geometric GCN model specified in models_pyg.py on synthetic task 1. I am using the default args in configs.py, but training seems to get stuck at ~15% accuracy and never improves. I did some minor hyperparameter tuning, but still no luck. This is significantly different from train.py and models.py, which easily approach 98% accuracy after a few hundred epochs. Is there some issue with the PyTorch Geometric implementation that I am missing? Has anyone had any luck training PyTorch Geometric models on the synthetic data? Thanks!
Hi there,
I am wondering whether it is possible to apply this to heterogeneous GNNs. Thanks.
hi,
I set normalize_adj=True in the preprocess_input_graph function and then used the normalized adj to train the GCN model and the explainer. When I explain a node, the loss becomes very large and the values in explainer.masked_adj are too small for any edge to be visualized.
Am I doing this right? By the way, my graph is a weighted directed graph.
I also want to know why ExplainModule has a self.coeffs parameter.
Thanks.
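One likely reason for the tiny values is that a symmetric normalization such as D^{-1/2} A D^{-1/2} divides each edge weight by the square roots of its endpoints' degrees, so mask-weighted entries can end up far below a typical plotting threshold. Whether preprocess_input_graph uses exactly this formula (or adds self-loops) is an assumption here. The sketch shows the normalization and a rescaling by the largest entry that makes a mask plottable regardless of its scale; both helpers are hypothetical:

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def sym_normalize(adj: Sequence[Sequence[float]]) -> Matrix:
    """Return D^{-1/2} A D^{-1/2}, where D holds the row sums (out-degrees) of A.

    Isolated nodes (degree 0) get a scaling factor of 0 instead of dividing by zero.
    """
    deg = [sum(row) for row in adj]
    inv_sqrt = [1.0 / math.sqrt(d) if d > 0 else 0.0 for d in deg]
    n = len(adj)
    return [[inv_sqrt[i] * adj[i][j] * inv_sqrt[j] for j in range(n)]
            for i in range(n)]


def rescale_by_max(mat: Sequence[Sequence[float]]) -> Matrix:
    """Divide every entry by the largest absolute entry so the strongest edge becomes 1.

    Relative ordering and zeros are preserved, so a fixed plotting threshold
    (e.g. 0.5) keeps working even when the raw mask values are tiny.
    """
    peak = max((abs(v) for row in mat for v in row), default=0.0)
    if peak == 0.0:
        return [list(map(float, row)) for row in mat]
    return [[v / peak for v in row] for row in mat]
```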
Hi,
This is a very interesting work!
The repo provides several datasets to test GNNExplainer. However, it is not obvious to me how a user can run it on his/her own model and dataset. Could you please explain how to do that?
Best,
Jingxuan
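As a starting point, the dense GCN models in models.py appear to take a batched dense adjacency of shape (batch, num_nodes, num_nodes) together with node features of shape (batch, num_nodes, feat_dim). A minimal sketch of building those arrays from an edge list, assuming that layout; the helper name is hypothetical, and the nested lists would still need to be converted with numpy.array or torch.tensor:

```python
from typing import List, Sequence, Tuple


def to_dense_batch(num_nodes: int,
                   edges: Sequence[Tuple[int, int]],
                   features: Sequence[Sequence[float]],
                   undirected: bool = True):
    """Return (adj, feat) with a leading batch dimension of size 1.

    adj[0][u][v] = 1.0 for every edge (u, v); feat[0] is the feature matrix.
    """
    if len(features) != num_nodes:
        raise ValueError("need one feature vector per node")
    adj: List[List[float]] = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        adj[u][v] = 1.0
        if undirected:
            adj[v][u] = 1.0  # mirror the edge for undirected graphs
    feat = [list(map(float, row)) for row in features]
    return [adj], [feat]
```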
What does the term lap_loss indicate intuitively, i.e., what are we trying to optimise or capture here?
D = torch.diag(torch.sum(self.masked_adj[0], 0))
m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
L = D - m_adj
pred_label_t = torch.tensor(pred_label, dtype=torch.float)
if self.args.gpu:
    pred_label_t = pred_label_t.cuda()
    L = L.cuda()
if self.graph_mode:
    lap_loss = 0
else:
    lap_loss = (self.coeffs["lap"]
        * (pred_label_t @ L @ pred_label_t)
        / self.adj.numel()
    )
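Intuitively, lap_loss is a graph-Laplacian smoothness penalty. With L = D - A for a symmetric masked adjacency A and predicted labels y, the quadratic form expands to y^T L y = sum over edges (i, j) of A_ij (y_i - y_j)^2, so it grows whenever the mask keeps a strongly weighted edge between nodes with different predicted labels. The sketch below checks that identity on a small example with plain lists (the function names are ours, not the repo's):

```python
from typing import Sequence


def laplacian_quadratic(adj: Sequence[Sequence[float]], y: Sequence[float]) -> float:
    """Compute y^T (D - A) y, with D the diagonal of column sums of A (as in the snippet)."""
    n = len(adj)
    deg = [sum(adj[i][j] for i in range(n)) for j in range(n)]  # column sums
    total = 0.0
    for i in range(n):
        for j in range(n):
            lij = (deg[i] if i == j else 0.0) - adj[i][j]  # entry of L = D - A
            total += y[i] * lij * y[j]
    return total


def edge_disagreement(adj: Sequence[Sequence[float]], y: Sequence[float]) -> float:
    """Sum over undirected edges i < j of A_ij * (y_i - y_j)^2."""
    n = len(adj)
    return sum(adj[i][j] * (y[i] - y[j]) ** 2
               for i in range(n) for j in range(i + 1, n))


if __name__ == "__main__":
    # Path 0-1-2 with mask weights 0.9 and 0.4; only edge (1, 2) crosses labels.
    adj = [[0, 0.9, 0], [0.9, 0, 0.4], [0, 0.4, 0]]
    y = [0, 0, 1]
    print(laplacian_quadratic(adj, y), edge_disagreement(adj, y))  # both 0.4
```

Minimizing this term therefore encourages the explanation to keep edges between nodes that share a predicted label.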
Hi
I'm using a GCN to classify the nodes of the KarateClub dataset, and then using GNNExplainer to explain node 12. However, when I explain the node twice, GNNExplainer gives me two different subgraphs with different node_feat_mask and edge_mask.
I'm confused by the different explanations generated from the same trained model.
I run the code in a Jupyter notebook:
cell one:
`
import torch
import torch.nn.functional as F
from torch_geometric.datasets import KarateClub
from torch_geometric.nn import GCNConv, GNNExplainer
import networkx as nx
import matplotlib.pyplot as plt

dataset = KarateClub()  # torch_geometric.datasets

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
x, edge_index = data.x, data.edge_index

for epoch in range(61):
    model.train()
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = F.nll_loss(out, data.y)
    loss.backward()
    optimizer.step()

    model.eval()
    _, pred = out.max(dim=1)
    correct = int(pred.eq(data.y).sum().item())
    acc = correct / data.num_nodes  # fraction of correctly classified nodes
    print('Epoch {} | Loss: {:.4f} | Accuracy: {:.4f}'.format(epoch, loss.item(), acc))
`
cell two:
explainer = GNNExplainer(model, epochs=60)
node_idx = 12
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
plt.show()
cell three:
explainer = GNNExplainer(model, epochs=61)
node_idx = 12
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y, threshold=0.6)
plt.show()
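Different results across runs are expected unless the random number generator is seeded: as far as we can tell, PyG's GNNExplainer initializes its feature and edge masks with random values before optimizing them, and the two cells above also use different epochs (60 vs. 61) and a different threshold. Calling torch.manual_seed(...) before each explain_node call should make CPU runs repeatable. The stdlib sketch below illustrates the principle with a hypothetical init_edge_mask helper, not PyG's actual code:

```python
import random
from typing import List, Optional


def init_edge_mask(num_edges: int, std: float, seed: Optional[int] = None) -> List[float]:
    """Hypothetical stand-in for a random mask initialization (Gaussian, mean 0).

    With seed=None the generator is seeded from the system, so every call
    starts the optimization from a different point.
    """
    rng = random.Random(seed)
    return [rng.gauss(0.0, std) for _ in range(num_edges)]


if __name__ == "__main__":
    a = init_edge_mask(5, 0.1, seed=0)
    b = init_edge_mask(5, 0.1, seed=0)
    print(a == b)  # True: same seed, same starting mask
```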
Dear authors,
Thanks for sharing the code for this nice paper!
I am particularly interested in how you "obtain a global explanation of the class, which can shed light on how the identified structure for a given node is related to a prototypical structure unique".
As mentioned in your paper, this is done by first identifying a reference node v_c for class c and its associated important computation subgraph G_S(v_c), and then aligning each of the identified computation subgraphs for all nodes in class c to the reference G_S(v_c), utilizing the idea of differentiable pooling [40].
I would like to reproduce your results for multi-instance / global explanations. Would you also share that part of the code, please? Many thanks!
The link to the Enron dataset you provided points to the raw data, but line 831 of train.py reads pickled data instead. Could you please provide the script that transforms the raw data into the processed pickle data?
Hi,
Thanks for sharing your code. Your paper discusses the link prediction task in Section 4.4, but the experiments are only conducted on node classification and graph classification. How can I train the model and run the explanation experiment on this task? Is it available in this repo? Thank you for your help!
The link given in the D3 section of the readme file is not working. It says that the notebook does not exist.
Hi,
is this library compatible with graph_nets models?
Describe the bug
hi-union-ppi.tsv and enron_slice_0.pkl are not found.
Also, when I use --bmname=[REDDIT-BINARY/Mutagenicity/Tox21_AHR], nothing happens and no checkpoint is created. Where should I put these three datasets after downloading them manually?
Hi, I'm wondering which function you used to generate the visualizations on MUTAG (Figures 4 and 5 in the paper)?
Or do the authors have any follow-up on this?
Some other detailed questions:
Sorry, I couldn't find them in the repository...
Describe the bug
The file requirements.txt is faulty.
The last line of requirements.txt,
pywidgets==7.5.1
should be changed to
ipywidgets==7.5.1
Otherwise pip cannot find the package "pywidgets==7.5.1".
[Environment]
Anaconda3
pip 20.0.2
Ubuntu 18.04
Hi @RexYing!
Thank you for the great package, I really enjoyed reading the paper as well.
I'm thinking of applying your explanation model to a deep graph CNN that I trained on a graph classification task using the stellargraph package (Repo here).
The architecture contains a pooling layer and a fully connected layer, and it wasn't clear to me whether your method would work with these types of layers.
Best,
Francisco