spygr-segmentation's People

Contributors

sunwng

spygr-segmentation's Issues

Confirmation

This is indeed the code for that paper. Thank you!

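NaN/Inf appears during training

Hello, I modified the GR_Module from this code and added it to my model, and found that NaN/Inf values appear after a single backward pass during training. I am using the AdamW optimizer with LR=1e-5. Lowering the LR does not solve the problem, and adding gradient clipping to the training code did not help either.

Printing the gradients during training shows that the failure is mainly caused by NaN gradients in the phi_conv convolution layer.

Have you run into a similar problem before, and if so, how did you solve it? I look forward to your reply. Thank you.

My modified module code is attached below.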

import torch
import torch.nn as nn


class GRModule(nn.Module):
    def __init__(self, channel, graph_feature=64):
        super(GRModule, self).__init__()

        self.channel = channel
        self.M = graph_feature

        self.phi_conv = nn.Sequential(
            nn.Conv2d(channel, self.M, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True)
        )
        self.glob_pool_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Conv2d(channel, self.M, kernel_size=1, stride=1, padding=0, bias=False)
        )

        self.graph_weight = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        bs = x.shape[0]
        # Compute $\phi x$ and $\phi^T x$
        x_phi_conv = self.phi_conv(x)
        x_phi = x_phi_conv.view([x_phi_conv.shape[0], -1, self.M])
        x_phi_T = x_phi_conv.view([x_phi_conv.shape[0], self.M, -1])

        # Compute $\Lambda$
        x_glob_pool_conv = self.glob_pool_conv(x)
        x_glob_diag = torch.zeros(bs, self.M, self.M).to('cuda' if x_phi.is_cuda else 'cpu')
        for i in range(bs):
            x_glob_diag[i, :, :] = torch.diag(x_glob_pool_conv[i, :, :, :].reshape(1, self.M))
        
        # $\tilde A = \phi \Lambda \phi^T$
        A_tilde = torch.matmul(torch.matmul(x_phi, x_glob_diag), x_phi_T)
        
        # $\tilde D_{ii} = \sum_j A_{ij}$
        D_sqrt_inv = torch.zeros_like(A_tilde).to('cuda' if A_tilde.is_cuda else 'cpu')
        diag_sum = torch.sum(A_tilde, 2)

        for i in range(bs):
            diag_sqrt = 1.0 / torch.sqrt(diag_sum[i, :])
            diag_sqrt[torch.isnan(diag_sqrt)] = 0
            diag_sqrt[torch.isinf(diag_sqrt)] = 0
            D_sqrt_inv[i, :, :] = torch.diag(diag_sqrt)

        # $I$
        I = torch.eye(D_sqrt_inv.shape[1]).to('cuda' if A_tilde.is_cuda else 'cpu')
        I = I.repeat(bs, 1, 1)

        # $\tilde L = I - \tilde D^{-\frac{1}{2}} \tilde A \tilde D^{-\frac{1}{2}}$
        L_tilde = I - torch.matmul(torch.matmul(D_sqrt_inv, A_tilde), D_sqrt_inv)

        # $\sigma(\tilde L X W)$
        out = torch.matmul(L_tilde, x.reshape(bs, -1, self.channel))
        out = out.reshape(bs, self.channel, x.shape[2], x.shape[3])
        out = self.graph_weight(out)

        return out
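
One plausible cause of the NaN gradients, judging only from the module above, is the degree normalization. The loop computes 1.0 / torch.sqrt(diag_sum) first and zeroes the NaN/Inf entries afterwards. That cleans up the forward values, but autograd still backpropagates through the division and the square root at those entries, where a zero upstream gradient multiplied by an infinite or NaN local derivative gives NaN. Such entries can occur whenever a row sum of A_tilde is zero or negative, which seems possible here because glob_pool_conv has no activation, so the diagonal of Lambda can be negative. The sketch below reproduces the effect on a toy vector and shows one common workaround (masking the input before the inverse square root, not the output after it). The function names masked_inv_sqrt and safe_inv_sqrt and the eps value are illustrative assumptions, not part of the repository.

```python
import torch


def masked_inv_sqrt(d):
    # Same pattern as the posted module: divide first, then zero out NaN/Inf.
    out = 1.0 / torch.sqrt(d)
    out[torch.isnan(out)] = 0
    out[torch.isinf(out)] = 0
    return out


def safe_inv_sqrt(d, eps=1e-6):
    # Hypothetical alternative: replace invalid degrees BEFORE rsqrt so that
    # neither the forward nor the backward pass ever sees sqrt(0) or sqrt(<0).
    valid = d > eps
    safe_d = torch.where(valid, d, torch.ones_like(d))
    return torch.where(valid, torch.rsqrt(safe_d), torch.zeros_like(d))


# Toy degree vector with one valid, one zero, and one negative entry.
d_masked = torch.tensor([4.0, 0.0, -1.0], requires_grad=True)
masked_inv_sqrt(d_masked).sum().backward()
grad_masked = d_masked.grad.clone()  # non-finite at the zero and negative entries

d_safe = torch.tensor([4.0, 0.0, -1.0], requires_grad=True)
out_safe = safe_inv_sqrt(d_safe)
out_safe.sum().backward()
grad_safe = d_safe.grad.clone()  # finite everywhere

print("masked gradient:", grad_masked)
print("safe gradient:  ", grad_safe)
```

Both functions return the same forward values on this input, so the difference only shows up in the gradients. Gradient clipping does not help in the masked variant because the NaN is produced inside the backward pass, before any clipping is applied. Whether this is the actual cause in the poster's model would need to be confirmed, for example by checking diag_sum for non-positive entries during training.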
