sunwng / spygr-segmentation
SpyGR code after paper review
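Is this the code for the CVPR 2020 paper "Spatial Pyramid Based Graph Reasoning for Semantic Segmentation"?
This is the code for that paper. Thanks a lot!
Hello, I modified the GR_Module from this code and added it to my own model, and I found that NaN/Inf values appear after a single backward pass during training. I am using the AdamW optimizer with LR=1e-5. Lowering the LR does not solve the problem, and adding gradient clipping to the training code does not solve it either.
Printing the gradients during training shows that the failure is mainly caused by NaN gradients in the phi_conv convolution layer.
Have you run into a similar problem before? If so, how did you solve it? I look forward to your reply. Thank you.
My modified module code is attached below.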
import torch
import torch.nn as nn


class GRModule(nn.Module):
    def __init__(self, channel, graph_feature=64):
        super(GRModule, self).__init__()
        self.channel = channel
        self.M = graph_feature
        self.phi_conv = nn.Sequential(
            nn.Conv2d(channel, self.M, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True)
        )
        self.glob_pool_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Conv2d(channel, self.M, kernel_size=1, stride=1, padding=0, bias=False)
        )
        self.graph_weight = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        bs = x.shape[0]
        # Compute $\phi x$ and $\phi^T x$
        x_phi_conv = self.phi_conv(x)
        x_phi = x_phi_conv.view([x_phi_conv.shape[0], -1, self.M])
        x_phi_T = x_phi_conv.view([x_phi_conv.shape[0], self.M, -1])
        # Compute $\Lambda$
        x_glob_pool_conv = self.glob_pool_conv(x)
        x_glob_diag = torch.zeros(bs, self.M, self.M).to('cuda' if x_phi.is_cuda else 'cpu')
        for i in range(bs):
            x_glob_diag[i, :, :] = torch.diag(x_glob_pool_conv[i, :, :, :].reshape(1, self.M))
        # $\tilde A = \phi \Lambda \phi^T$
        A_tilde = torch.matmul(torch.matmul(x_phi, x_glob_diag), x_phi_T)
        # $\tilde D_{ii} = \sum_j \tilde A_{ij}$
        D_sqrt_inv = torch.zeros_like(A_tilde).to('cuda' if A_tilde.is_cuda else 'cpu')
        diag_sum = torch.sum(A_tilde, 2)
        for i in range(bs):
            diag_sqrt = 1.0 / torch.sqrt(diag_sum[i, :])
            diag_sqrt[torch.isnan(diag_sqrt)] = 0
            diag_sqrt[torch.isinf(diag_sqrt)] = 0
            D_sqrt_inv[i, :, :] = torch.diag(diag_sqrt)
        # $I$
        I = torch.eye(D_sqrt_inv.shape[1]).to('cuda' if A_tilde.is_cuda else 'cpu')
        I = I.repeat(bs, 1, 1)
        # $\tilde L = I - \tilde D^{-\frac{1}{2}} \tilde A \tilde D^{-\frac{1}{2}}$
        L_tilde = I - torch.matmul(torch.matmul(D_sqrt_inv, A_tilde), D_sqrt_inv)
        # $\sigma(\tilde L X W)$
        out = torch.matmul(L_tilde, x.reshape(bs, -1, self.channel))
        out = out.reshape(bs, self.channel, x.shape[2], x.shape[3])
        out = self.graph_weight(out)
        return out
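
One possible cause, offered as a sketch rather than a confirmed diagnosis: the module above computes `1.0 / torch.sqrt(diag_sum)` first and zeroes the NaN/Inf entries afterwards. That cleans up the forward values, but in the backward pass the zero upstream gradient at those positions is still multiplied by an infinite or NaN local derivative, and the result is NaN. Because `glob_pool_conv` has no ReLU, the diagonal of $\Lambda$ can be negative, so zero or negative row sums of $\tilde A$ seem plausible. The snippet below reproduces the effect on a tiny tensor and shows one way to sanitize the input before the inverse square root. The names `inv_sqrt_masked`, `inv_sqrt_safe`, `normalized_laplacian` and the `eps` threshold are assumptions made for illustration and are not part of the repository.

```python
import torch


def inv_sqrt_masked(d: torch.Tensor) -> torch.Tensor:
    """Mimics the pattern in the module above: 1/sqrt(d), then zero NaN/Inf.

    Written with torch.where instead of in-place indexing; the forward
    values are the same.
    """
    out = 1.0 / torch.sqrt(d)
    bad = torch.isnan(out) | torch.isinf(out)
    return torch.where(bad, torch.zeros_like(out), out)


def inv_sqrt_safe(d: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical alternative: replace invalid inputs BEFORE rsqrt.

    Entries with d <= eps are swapped for 1.0 so rsqrt and its derivative
    stay finite, and the corresponding outputs are then set to 0.
    """
    valid = d > eps
    safe_d = torch.where(valid, d, torch.ones_like(d))
    return torch.where(valid, torch.rsqrt(safe_d), torch.zeros_like(d))


def normalized_laplacian(A: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """L = I - D^{-1/2} A D^{-1/2} for a batch of (N, N) matrices, without loops."""
    d_inv_sqrt = inv_sqrt_safe(A.sum(dim=2), eps)  # shape (B, N)
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2} via broadcasting.
    A_norm = d_inv_sqrt.unsqueeze(2) * A * d_inv_sqrt.unsqueeze(1)
    I = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    return I - A_norm


if __name__ == "__main__":
    for fn in (inv_sqrt_masked, inv_sqrt_safe):
        d = torch.tensor([4.0, 0.0, -1.0], requires_grad=True)
        out = fn(d)
        out.sum().backward()
        print(fn.__name__, "forward:", out.detach(), "grad:", d.grad)
```

With the row sums `[4, 0, -1]`, both functions return finite forward values, but only the masked version leaves NaN in the gradient of the zero and negative entries. If the NaN persists after a change like this, `torch.autograd.set_detect_anomaly(True)` can help locate the first operation whose backward produces it.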