
rusty1s commented on June 7, 2024

The kernels for computing gradients with respect to the pseudo-coordinates exist, and their tests pass. I would have expected this to work (although I have never used that feature in one of my own projects). Can you show me a minimal example that reproduces the error?

from pytorch_spline_conv.
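
For context on what "pseudo grad computation" means: SplineConv weights each edge with a B-spline basis evaluated at that edge's pseudo-coordinates, so the gradient with respect to a pseudo-coordinate is the derivative of that basis. Below is a minimal 1D, single-channel sketch, assuming an open spline of degree 1 (SplineConv's default `degree=1`). The helper names are hypothetical and this is not the library's actual kernel; it only compares an analytic derivative against a central finite difference.

```python
import math


def linear_open_basis(u, kernel_size):
    """Degree-1 open B-spline basis for a pseudo-coordinate u in [0, 1].

    Hypothetical helper, modeled loosely on SplineConv's basis computation.
    Returns (kernel_index, basis_weight) pairs for the two active knots.
    """
    v = u * (kernel_size - 1)                     # open spline: scale by K - degree
    i = min(int(math.floor(v)), kernel_size - 2)  # clamp so u == 1 uses the last segment
    frac = v - i
    return [(i, 1.0 - frac), (i + 1, frac)]


def spline_out(x, u, weights):
    """Output of a 1D, single-channel spline convolution for one edge."""
    return x * sum(b * weights[k] for k, b in linear_open_basis(u, len(weights)))


def spline_out_grad_u(x, u, weights):
    """Analytic d(out)/du, i.e. the gradient w.r.t. the pseudo-coordinate."""
    K = len(weights)
    (i, _), (j, _) = linear_open_basis(u, K)
    # out = x * ((1 - frac) * w[i] + frac * w[j]) and d(frac)/du = K - 1
    return x * (weights[j] - weights[i]) * (K - 1)


weights = [0.5, -1.0, 2.0, 0.25, 1.5]  # kernel_size = 5, arbitrary example values
x, u, eps = 3.0, 0.3, 1e-6
numeric = (spline_out(x, u + eps, weights) - spline_out(x, u - eps, weights)) / (2 * eps)
analytic = spline_out_grad_u(x, u, weights)
print(analytic, numeric)
```

Within a segment the degree-1 basis is linear in `u`, so the derivative is piecewise constant and the two values should agree closely; exactly at a knot the derivative is only one-sided.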

sofitiwari commented on June 7, 2024

Hey, thanks for helping out. Here is a minimal example that reproduces the error. I am using torch 1.4.0 and torch_geometric 1.3.2. If I detach the pseudo-coordinates, the network runs without error, but the gradients don't flow back beyond the point of detachment. Let me know what you think :).

import torch
import torch.nn.functional as F
from torch_geometric.nn import SplineConv
import torch_geometric
from torch_geometric.data import DataLoader, Data

n_epochs = 2
train_dataset = []
for i in range(10):
    x = torch.ones((3,1))
    edge_index = torch.tensor([[0, 1, 1, 2],
                            [1, 0, 2, 1]], dtype=torch.long)
    pos = torch.rand((3,3))
    edge_attr = torch.rand((4,3))
    target = torch.randint(0,2, (3,34))
    data = Data(x = x, edge_index = edge_index, edge_attr = edge_attr, pos = pos, target = target)
    train_dataset.append(data)

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.stn_conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add')
        self.stn_conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add')
        self.stn_lin1 = torch.nn.Linear(64, 256)
        self.stn_lin2 = torch.nn.Linear(256, 9)
        self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add')
        self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add')
        self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')
        self.lin1 = torch.nn.Linear(64, 256)
        self.lin2 = torch.nn.Linear(256, 34) 

    def forward(self, data):
        x, edge_index, pseudo = data.x, data.edge_index, data.edge_attr
        x = F.elu(self.stn_conv1(x, edge_index, pseudo))
        x = F.elu(self.stn_conv2(x, edge_index, pseudo))
        x = F.elu(self.stn_lin1(x))
        x = F.dropout(x, training=self.training)
        x = self.stn_lin2(x)
        x = torch.mean(x, dim=0)
        x = F.normalize(x.view(9), p=2, dim=0)
        rotation_mat = x.view(3,3)
        update_pos = torch.matmul(data.pos,rotation_mat).squeeze(0) 
        update_x = data.x.new_ones(data.x.size())
        row, col = data.edge_index 
        cart = update_pos[col] - update_pos[row] 
        #cart = F.softmax(cart, dim=0).detach()  # Runs without error, but gradients don't flow back to stn_conv1, ..., stn_lin2
        cart = F.softmax(cart, dim=0)
        x = F.elu(self.conv1(update_x, edge_index, cart))
        x = F.elu(self.conv2(x, edge_index, cart))
        x = F.elu(self.conv3(x, edge_index, cart))
        x = F.elu(self.lin1(x))
        x = F.dropout(x, training=self.training)
        x = self.lin2(x)
        return (x)

def train(epoch):
    model.train()
    total_loss = 0
    for data in train_loader:
        target = data.target.to(device).float()
        pred = model(data.to(device))
        optimizer.zero_grad()
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()
        # print("---gradients after backprop---")
        # for n, p in model.named_parameters():
        #     print(n, p.grad)
        total_loss += loss.item()
    return(total_loss/len(train_dataset))

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device).float()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(1, n_epochs + 1):  # run all n_epochs epochs (range end is exclusive)
    train_loss = train(epoch)
    print('Epoch: {:02d}, train_loss: {:.4f}'.format(epoch, train_loss))
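
The commented-out `.detach()` line avoids the error only because it cuts the autograd graph: in PyTorch, `Tensor.detach()` returns a tensor that is excluded from gradient tracking, so every parameter upstream of `cart` (the `stn_*` layers) receives no gradient, which the commented-out `named_parameters()` loop in `train` would reveal. The pure-Python sketch below shows that effect in isolation. The `Scalar` class is a hypothetical toy reverse-mode autodiff, not PyTorch's autograd, and `forward` is a hypothetical stand-in for the two halves of `Net.forward`.

```python
class Scalar:
    """Toy reverse-mode autodiff node (hypothetical, for illustration only)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = None          # stays None if no gradient ever reaches this node
        self._parents = parents   # tuples of (parent_node, local_derivative)

    def __add__(self, other):
        return Scalar(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Scalar(self.value * other.value,
                      ((self, other.value), (other, self.value)))

    def detach(self):
        # Same value, but no parents: gradients cannot flow past this point.
        return Scalar(self.value)

    def backward(self):
        self._accumulate(1.0)

    def _accumulate(self, upstream):
        # Path-by-path chain rule: correct, but exponential on large graphs.
        self.grad = (self.grad or 0.0) + upstream
        for parent, local in self._parents:
            parent._accumulate(upstream * local)


def forward(stn_weight, conv_weight, pos, use_detach):
    cart = stn_weight * pos       # stands in for the rotated edge vectors
    if use_detach:
        cart = cart.detach()      # mirrors the commented-out line in Net.forward
    return conv_weight * cart     # stands in for conv1(update_x, edge_index, cart)


for use_detach in (False, True):
    stn_w, conv_w, pos = Scalar(0.5), Scalar(2.0), Scalar(3.0)
    loss = forward(stn_w, conv_w, pos, use_detach)
    loss.backward()
    print(f"detach={use_detach}: stn grad={stn_w.grad}, conv grad={conv_w.grad}")
```

With `use_detach=True` the "stn" weight's gradient stays `None` while the "conv" weight's gradient is unchanged, mirroring the reported behavior: no error, but nothing upstream of `cart` gets trained.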


rusty1s commented on June 7, 2024

Works for me on latest PyTorch Geometric (but I would have assumed that it also works on earlier versions).

sofitiwari commented on June 7, 2024

I updated PyTorch Geometric and it works fine now. Thanks for the quick replies!
