felixopolka / stgcn-pytorch
Implementation of spatio-temporal graph convolutional network with PyTorch
License: MIT License
Hi,
I am confused by the adjacency matrix in your adj_mat.npy. It doesn't seem to be an adjacency matrix.
Best
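One way to see what adj_mat.npy actually contains is to summarize the properties one would expect of a weighted adjacency matrix: square shape, symmetry, weight range, edge density, and diagonal. This is a minimal NumPy sketch, assuming the file loads with np.load as a 2-D float array; describe_adjacency is a hypothetical helper, not part of the repository.

```python
import numpy as np


def describe_adjacency(A: np.ndarray) -> dict:
    """Summarize properties one would expect from a weighted adjacency matrix."""
    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {A.shape}")
    # All entries except the diagonal, flattened.
    off_diag = A[~np.eye(A.shape[0], dtype=bool)]
    return {
        "num_nodes": A.shape[0],
        "symmetric": bool(np.allclose(A, A.T)),
        "min_weight": float(A.min()),
        "max_weight": float(A.max()),
        # Fraction of off-diagonal entries that are nonzero (edge density).
        "density": float(np.count_nonzero(off_diag) / max(off_diag.size, 1)),
        "diagonal_all_ones": bool(np.allclose(np.diag(A), 1.0)),
    }


if __name__ == "__main__":
    # Hypothetical usage on the real file: A = np.load("adj_mat.npy")
    A = np.array([[1.0, 0.5, 0.0],
                  [0.2, 1.0, 0.0],
                  [0.0, 0.0, 1.0]])
    print(describe_adjacency(A))
```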
The whole validation set (6,854 samples, i.e. effectively batch_size = 6,854) is passed through the network in one go after each training epoch, which causes a 'CUDA out of memory' error.
Here is how I fixed it in main.py (the original lines are kept as comments):
with torch.no_grad():
    net.eval()
    val_input = val_input.to(device=args.device)
    val_target = val_target.to(device=args.device)
    # out = net(A_wave, val_input)
    # --------------------------------
    tmp_val_losses = []
    tmp_maes = []
    for i in range(0, val_input.shape[0], batch_size):
        out = net(A_wave, val_input[i:i+batch_size, ...])
        loss = loss_criterion(out, val_target[i:i+batch_size, ...]).to(device="cpu")
        tmp_val_losses.append(np.ndarray.item(loss.detach().numpy()))
        out_unnormalized = out.detach().cpu().numpy() * stds[0] + means[0]
        target_unnormalized = val_target[i:i+batch_size, ...].detach().cpu().numpy() * stds[0] + means[0]
        mae = np.mean(np.absolute(out_unnormalized - target_unnormalized))
        tmp_maes.append(mae)
    val_loss = sum(tmp_val_losses) / len(tmp_val_losses)
    validation_losses.append(val_loss)
    mae = sum(tmp_maes) / len(tmp_maes)
    validation_maes.append(mae)
    # --------------------------------
    # val_loss = loss_criterion(out, val_target).to(device="cpu")
    # validation_losses.append(np.ndarray.item(val_loss.detach().numpy()))
    # out_unnormalized = out.detach().cpu().numpy() * stds[0] + means[0]
    # target_unnormalized = val_target.detach().cpu().numpy() * stds[0] + means[0]
    # mae = np.mean(np.absolute(out_unnormalized - target_unnormalized))
    # validation_maes.append(mae)
    out = None
    val_input = val_input.to(device="cpu")
    val_target = val_target.to(device="cpu")

print("Training loss: {}".format(training_losses[-1]))
print("Validation loss: {}".format(validation_losses[-1]))
print("Validation MAE: {}".format(validation_maes[-1]))
Has anyone else run into this problem?
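The fix above averages the per-batch losses and MAEs with equal weight. Unless 6,854 is a multiple of batch_size, the last batch is smaller and counts slightly more than its share. A minimal, framework-agnostic sketch of size-weighted averaging follows, using NumPy arrays; batched_metrics and its predict callable are hypothetical names, and it assumes the loss is mean squared error. In main.py, predict would wrap net(A_wave, ...) under torch.no_grad(), and moving only the current slice to args.device would also avoid keeping the whole validation tensor on the GPU.

```python
import numpy as np


def batched_metrics(predict, inputs, targets, batch_size, std=1.0):
    """Evaluate `predict` in mini-batches and return (mse, mae) over the whole set.

    Each batch's mean is weighted by the number of samples in that batch, so the
    result equals the metric computed on the full set in one pass, even when the
    last batch is smaller than `batch_size`.
    """
    n = inputs.shape[0]
    sq_err_sum = 0.0
    abs_err_sum = 0.0
    for start in range(0, n, batch_size):
        x = inputs[start:start + batch_size]
        y = targets[start:start + batch_size]
        out = predict(x)
        # Loss on normalized values (assumed MSE, like an nn.MSELoss criterion).
        sq_err_sum += float(np.mean((out - y) ** 2)) * len(x)
        # MAE on unnormalized values: (out*std + mean) - (y*std + mean) = (out - y)*std,
        # so the mean offset cancels and only std is needed.
        abs_err_sum += float(np.mean(np.abs((out - y) * std))) * len(x)
    return sq_err_sum / n, abs_err_sum / n


if __name__ == "__main__":
    X = np.arange(10, dtype=float).reshape(10, 1)
    Y = np.zeros_like(X)
    print(batched_metrics(lambda x: x, X, Y, batch_size=4, std=2.0))
```

With inputs 0 to 9, zero targets, and batch_size=4 (batches of 4, 4 and 2), this returns an MSE of 28.5, identical to a single full-set pass, whereas the equal-weight average of the three batch MSEs would be about 35.83.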
Why is the number of input and output channels set to 64?
Why is the number of spatial channels set to 16?
We predict the traffic speed at each node, so the output is 207 × 1. What numbers of input and output channels should I use if I want to predict the traffic flow from each node to every other node, i.e. 207 × 207?
Thanks a lot :-)
X = X.transpose((1, 2, 0))
Why are there 2 features (observations) per station at each time step?
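What transpose((1, 2, 0)) does to the axes can be checked on a small stand-in array. This minimal sketch assumes node_values.npy is laid out as (num_timesteps, num_nodes, num_features), with 207 nodes and 2 features as in the questions here; the array is random and the 12 time steps are arbitrary.

```python
import numpy as np

# Hypothetical stand-in for np.load("node_values.npy"); assumed layout is
# (num_timesteps, num_nodes, num_features).
num_timesteps, num_nodes, num_features = 12, 207, 2
rng = np.random.default_rng(0)
X = rng.standard_normal((num_timesteps, num_nodes, num_features)).astype(np.float32)

# transpose((1, 2, 0)): old axis 1 becomes new axis 0, old axis 2 becomes
# new axis 1, and old axis 0 becomes new axis 2.
X_t = X.transpose((1, 2, 0))
print(X_t.shape)  # (207, 2, 12): (num_nodes, num_features, num_timesteps)

# The feature vector of station 3 at time step 5, before and after:
assert np.array_equal(X[5, 3, :], X_t[3, :, 5])
```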
When running main.py, I get a runtime error: "one of the variables needed for gradient computation has been modified by an inplace operation". Which lines could cause this? Is it line 83 (torch.einsum) in stgcn.py?
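The thread doesn't identify the offending line, but PyTorch's anomaly mode can help locate it: with torch.autograd.set_detect_anomaly(True) (or the torch.autograd.detect_anomaly() context manager) enabled, the error is preceded by a warning containing the traceback of the forward call whose saved tensor was later modified. This minimal sketch reproduces the same error on a toy tensor, not on the STGCN model, and shows the out-of-place fix; the function names are hypothetical.

```python
import torch


def backward_with_inplace_edit() -> str:
    """Reproduce the error: sigmoid saves its output, which is then edited in place."""
    x = torch.ones(3, requires_grad=True)
    y = torch.sigmoid(x)   # autograd saves y for sigmoid's backward pass
    y.mul_(2)              # in-place op bumps y's version counter
    try:
        y.sum().backward()
    except RuntimeError as err:
        return str(err)
    return ""


def backward_out_of_place() -> torch.Tensor:
    """Same computation without the in-place op; backward succeeds."""
    x = torch.ones(3, requires_grad=True)
    y = torch.sigmoid(x) * 2  # new tensor; the saved sigmoid output stays intact
    y.sum().backward()
    return x.grad


if __name__ == "__main__":
    # Anomaly mode records forward tracebacks; it slows things down, so enable
    # it only while debugging.
    with torch.autograd.detect_anomaly():
        print(backward_with_inplace_edit())
    print(backward_out_of_place())
```

In the real model, the same idea applies: run one training step with anomaly mode on, read the forward traceback it prints, and replace the in-place operation it points to with an out-of-place equivalent.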
Hi,
Thanks for your work.
By the way, could I use the dataset from https://github.com/Davidham3/STGCN with your code?
Hi,
I couldn't find any code for the first-order approximation in your implementation. Does that mean your code only provides the Chebyshev approximation?
Best
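For reference, the STGCN paper describes two ways to approximate the graph convolution: a Chebyshev polynomial expansion of order K in the rescaled Laplacian, and a first-order approximation that, with the renormalization W̃ = W + I, reduces to the single operator D̃^(-1/2) W̃ D̃^(-1/2). This minimal NumPy sketch builds both operators, assuming a symmetric, non-negative weight matrix W; the function names are hypothetical and the code is not taken from the repository. Comparing the repository's A_wave against these would show which one it corresponds to.

```python
import numpy as np


def normalized_laplacian(W: np.ndarray) -> np.ndarray:
    """L = I - D^{-1/2} W D^{-1/2}; isolated nodes get zero scaling."""
    d = W.sum(axis=1)
    inv_sqrt = np.zeros_like(d, dtype=float)
    nz = d > 0
    inv_sqrt[nz] = 1.0 / np.sqrt(d[nz])
    return np.eye(W.shape[0]) - inv_sqrt[:, None] * W * inv_sqrt[None, :]


def chebyshev_supports(W: np.ndarray, K: int) -> list:
    """T_0 .. T_{K-1} evaluated at the rescaled Laplacian 2L/lambda_max - I."""
    n = W.shape[0]
    L = normalized_laplacian(W)
    lam_max = np.linalg.eigvalsh(L).max()  # eigvalsh assumes W is symmetric
    L_tilde = 2.0 * L / lam_max - np.eye(n)
    supports = [np.eye(n), L_tilde]
    for _ in range(2, K):
        # Chebyshev recurrence: T_k = 2 L_tilde T_{k-1} - T_{k-2}
        supports.append(2.0 * L_tilde @ supports[-1] - supports[-2])
    return supports[:K]


def first_order_support(W: np.ndarray) -> np.ndarray:
    """Renormalized first-order operator D~^{-1/2} (W + I) D~^{-1/2}."""
    W_tilde = W + np.eye(W.shape[0])
    inv_sqrt = 1.0 / np.sqrt(W_tilde.sum(axis=1))  # degree >= 1 for W >= 0
    return inv_sqrt[:, None] * W_tilde * inv_sqrt[None, :]


if __name__ == "__main__":
    W = np.array([[0.0, 1.0], [1.0, 0.0]])
    print(chebyshev_supports(W, 3))
    print(first_order_support(W))
```

A K-term Chebyshev filter then computes the sum over k of theta_k * T_k @ x with the returned supports, while the first-order filter computes theta * first_order_support(W) @ x, i.e. one learned parameter per input/output channel pair instead of K.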
What do the features in node_values.npy mean?
In the TensorFlow implementation, the temporal_conv_layer result is the product of a convolution and a sigmoid:
return (x_conv[:, :, :, 0:c_out] + x_input) * tf.nn.sigmoid(x_conv[:, :, :, -c_out:])
Could you explain why addition is used here instead?
temp = self.conv1(X) + torch.sigmoid(self.conv2(X))
out = F.relu(temp + self.conv3(X))
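Arithmetically, the two expressions behave differently. In the TensorFlow line, P = x_conv[..., 0:c_out] plus the residual is multiplied by sigmoid(Q), with Q = x_conv[..., -c_out:], so the sigmoid acts as a gate in [0, 1] that can scale each channel down to zero (a gated linear unit). In the PyTorch lines, sigmoid(conv2(X)) is added, so it can only shift the pre-activation by a value in (0, 1) and cannot suppress conv1 or conv3. This minimal NumPy sketch works on precomputed convolution outputs; p, q, c1, c2, c3 are hypothetical stand-ins, and the sketch says nothing about which form the author intended.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def glu_gate(p, q, residual):
    """TF form: (P + X_input) * sigmoid(Q); the sigmoid scales each channel."""
    return (p + residual) * sigmoid(q)


def additive_gate(c1, c2, c3):
    """PyTorch form: relu(conv1 + sigmoid(conv2) + conv3); the sigmoid only shifts."""
    return np.maximum(c1 + sigmoid(c2) + c3, 0.0)


if __name__ == "__main__":
    signal = np.array([3.0, -1.0, 2.0])
    closed = np.full(3, -50.0)  # strongly negative gate pre-activation
    zeros = np.zeros(3)
    print(glu_gate(signal, closed, zeros))       # ~0: the gate suppresses the signal
    print(additive_gate(signal, closed, zeros))  # ~[3, 0, 2]: the signal passes through
```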
Hi author,
Thank you for the PyTorch version of STGCN. I found that the matrix in adj_mat.npy is not symmetric, but the weights should be W_ij = exp(-d_ij^2 / sigma^2) (when W_ij ≠ 0). So why is it not symmetric?
Best wishes
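The Gaussian kernel is symmetric only if the distances are: W_ij equals W_ji exactly when d_ij equals d_ji, and thresholding at epsilon keeps that property. If the distances were directed (for example, road-network distances, where travelling from i to j can differ from j to i), an asymmetric matrix would be expected; whether adj_mat.npy was built that way is an assumption, not something stated here. This minimal NumPy sketch follows the kernel form from the STGCN paper; gaussian_kernel_adjacency and the toy distances are hypothetical.

```python
import numpy as np


def gaussian_kernel_adjacency(dist: np.ndarray, sigma: float, epsilon: float) -> np.ndarray:
    """W_ij = exp(-d_ij^2 / sigma^2) if i != j and that value >= epsilon, else 0."""
    W = np.exp(-np.square(dist) / sigma ** 2)
    W[W < epsilon] = 0.0
    np.fill_diagonal(W, 0.0)
    return W


if __name__ == "__main__":
    # Hypothetical directed distances: going 0 -> 1 is shorter than 1 -> 0,
    # as can happen with one-way roads.
    dist = np.array([[0.0, 1.0, 5.0],
                     [2.0, 0.0, 1.0],
                     [5.0, 1.0, 0.0]])
    W = gaussian_kernel_adjacency(dist, sigma=2.0, epsilon=0.1)
    print(np.allclose(W, W.T))  # False: d_01 != d_10, so W_01 != W_10
    sym = gaussian_kernel_adjacency(np.minimum(dist, dist.T), sigma=2.0, epsilon=0.1)
    print(np.allclose(sym, sym.T))  # True once the distances are symmetric
```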