Comments (4)
The kernels for computing gradients with respect to pseudo exist, and their tests pass. I would have expected this to work (although I have never used that feature in one of my projects). Can you show me a minimal example that reproduces the error?
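One way to check whether gradients actually reach the pseudo-coordinates is to turn them into a fresh leaf tensor, run a backward pass, and inspect .grad. The sketch below assumes plain PyTorch only. `grad_reaches` and `toy_conv` are hypothetical names, and `toy_conv` is a stand-in for SplineConv so the check runs without torch_spline_conv installed.

```python
import torch


def grad_reaches(fn, pseudo):
    """Return True if a backward pass through fn(pseudo) gives pseudo a finite gradient."""
    # Fresh leaf tensor so the check does not depend on any earlier graph.
    pseudo = pseudo.detach().clone().requires_grad_(True)
    out = fn(pseudo)
    if not out.requires_grad:
        # fn cut the graph (e.g. via .detach()), so nothing can flow back.
        return False
    out.sum().backward()
    return pseudo.grad is not None and bool(torch.isfinite(pseudo.grad).all())


def toy_conv(pseudo):
    # Hypothetical stand-in for SplineConv: any differentiable use of pseudo.
    weight = torch.linspace(0.0, 1.0, pseudo.size(1))
    return (pseudo * weight).sum(dim=1)


pseudo = torch.rand(4, 3)
print(grad_reaches(toy_conv, pseudo))                        # True
print(grad_reaches(lambda p: toy_conv(p.detach()), pseudo))  # False
```

With PyTorch Geometric installed, fn would instead wrap a real layer, e.g. lambda p: conv(x, edge_index, p) for a SplineConv instance conv.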
from pytorch_spline_conv.
Hey, thanks for helping out. Here is minimal code that reproduces the error. I am using torch 1.4.0 and torch_geometric 1.3.2. If I detach the pseudo-coordinates, the network runs without error, but the gradients don't flow back past the point of detachment. Let me know what you think :).
import torch
import torch.nn.functional as F
from torch_geometric.nn import SplineConv
import torch_geometric
from torch_geometric.data import DataLoader, Data

n_epochs = 2

train_dataset = []
for i in range(10):
    x = torch.ones((3, 1))
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]], dtype=torch.long)
    pos = torch.rand((3, 3))
    edge_attr = torch.rand((4, 3))
    target = torch.randint(0, 2, (3, 34))
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, target=target)
    train_dataset.append(data)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Spatial-transformer branch: predicts a 3x3 matrix that is applied to pos.
        self.stn_conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add')
        self.stn_conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add')
        self.stn_lin1 = torch.nn.Linear(64, 256)
        self.stn_lin2 = torch.nn.Linear(256, 9)
        # Main branch: runs on pseudo-coordinates computed from the transformed positions.
        self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add')
        self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add')
        self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')
        self.lin1 = torch.nn.Linear(64, 256)
        self.lin2 = torch.nn.Linear(256, 34)

    def forward(self, data):
        x, edge_index, pseudo = data.x, data.edge_index, data.edge_attr

        # Predict a 3x3 matrix from the input graph.
        x = F.elu(self.stn_conv1(x, edge_index, pseudo))
        x = F.elu(self.stn_conv2(x, edge_index, pseudo))
        x = F.elu(self.stn_lin1(x))
        x = F.dropout(x, training=self.training)
        x = self.stn_lin2(x)
        x = torch.mean(x, dim=0)
        x = F.normalize(x.view(9), p=2, dim=0)
        rotation_mat = x.view(3, 3)

        # Transform the node positions and recompute Cartesian pseudo-coordinates.
        update_pos = torch.matmul(data.pos, rotation_mat).squeeze(0)
        update_x = data.x.new_ones(data.x.size())
        row, col = data.edge_index
        cart = update_pos[col] - update_pos[row]
        # cart = F.softmax(cart, dim=0).detach()  # This passes, but the gradients don't flow back to stn_conv1, ..., stn_lin2
        cart = F.softmax(cart, dim=0)

        x = F.elu(self.conv1(update_x, edge_index, cart))
        x = F.elu(self.conv2(x, edge_index, cart))
        x = F.elu(self.conv3(x, edge_index, cart))
        x = F.elu(self.lin1(x))
        x = F.dropout(x, training=self.training)
        x = self.lin2(x)
        return x
def train(epoch):
    model.train()
    total_loss = 0
    for data in train_loader:
        target = data.target.to(device).float()
        pred = model(data.to(device))
        optimizer.zero_grad()
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()
        # print("---gradients after backprop---")
        # for n, p in model.named_parameters():
        #     print(n, p.grad)
        total_loss += loss.item()
    return total_loss / len(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device).float()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()

# range(1, n_epochs + 1) so that all n_epochs epochs actually run.
for epoch in range(1, n_epochs + 1):
    train_loss = train(epoch)
    print('Epoch: {:02d}, train_loss: {:.4f}'.format(epoch, train_loss))
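The detach behaviour described in this report can be reproduced without SplineConv. The sketch below mimics the path from the STN output through rotation_mat and update_pos to cart using plain PyTorch. `theta`, `pseudo_coords` and `theta_grad` are hypothetical names, `theta` stands in for the output of stn_lin2, and a single learnable scale stands in for the downstream SplineConv layers.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the STN branch output: a learnable 9-vector.
theta = torch.nn.Parameter(torch.randn(9))
pos = torch.rand(3, 3)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])


def pseudo_coords(detach):
    # Same steps as the model: normalize, reshape to 3x3, transform pos.
    rotation_mat = F.normalize(theta, p=2, dim=0).view(3, 3)
    update_pos = pos @ rotation_mat
    row, col = edge_index
    cart = F.softmax(update_pos[col] - update_pos[row], dim=0)
    return cart.detach() if detach else cart


def theta_grad(detach):
    theta.grad = None
    cart = pseudo_coords(detach)
    # Stand-in for the downstream SplineConv layers: a learnable scale on cart.
    scale = torch.nn.Parameter(torch.ones(()))
    loss = (scale * cart).pow(2).sum()
    loss.backward()
    return theta.grad


print(theta_grad(detach=False) is not None)  # True
print(theta_grad(detach=True))               # None
```

The loss still has a gradient in both cases (through scale), so backward() succeeds either way; only with detach=True does theta.grad stay None, which matches the "no error, but no gradient upstream" symptom.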
Works for me on the latest PyTorch Geometric (but I would have assumed that it also works on earlier versions).
I updated PyTorch Geometric and it works fine now. Thanks for the quick replies!
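Because the fix here was an upgrade, it helps to report exact package versions when filing a similar issue. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); `installed_version` is a hypothetical helper, and the distribution names below are assumed to be the PyPI names of the three packages.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


for name in ("torch", "torch-geometric", "torch-spline-conv"):
    print(name, installed_version(name))
```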
Related Issues
- Possibility of creating a tutorial?
- Tutorial Segmentation Fault
- Performing Spline Convolution for evaluating Spline Surface
- Only datatype float excepted
- Installation issues
- Using spline_conv to approximate gradients on a graph
- graph preprocessing for SplineConv layer
- Getting in depth understanding of the spline kernels
- Cannot import torch-spline-conv when installed from pip wheel in Torch 1.9.0
- Error importing CPU wheel on machine with CUDA
- ImportError occurs in Google Colab
- CUDA libraries not generated when building pytorch-spline-conv
- Cannot install pyg without pytorch_spline_conv in conda
- Plotting Learned Filters
- ImportError: cannot import name 'SplineConv' from 'torch_spline_conv'
- CUDA error: an illegal memory access
- Question: Similiar to Receptive Field
- GLIBC incompatibility issue on RHEL8
- SplineConv issue with 1024 x 1024 x 3 image