imbalancedlearning's People

Contributors

harboryuan, iboing

imbalancedlearning's Issues

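About the batchnorm in the Classifier

Hello,

Many large networks nowadays end with a single fc layer; for example, the classifier head after a post-norm transformer can do without any normalization at all. The paper does not mention the batchnorm in the ETF classifier. It reminds me of Logit Attenuating Weight Normalization (which targets calibration, but arguably also mitigates imbalance to some extent) and several other papers that regularize the logits. Is this batchnorm optional, or is it fairly important? Thanks.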
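To make the question concrete, here is a minimal NumPy sketch of what `nn.BatchNorm1d` computes in training mode on a `(batch, features)` input, as used by the `BN_H` layer in the `ETF_Classifier` code quoted in a later issue. Each feature dimension is standardized with the batch mean and the biased batch variance, then scaled and shifted by the learnable weight and bias (initialized to 1 and 0). The function name `batchnorm1d_train` is made up for illustration, and the sketch ignores the running statistics used at eval time.

```python
import numpy as np

def batchnorm1d_train(x, gamma=None, beta=None, eps=1e-5):
    """Training-mode BatchNorm1d on a (batch, features) array (illustrative sketch).

    Each feature column is standardized with the batch mean and the biased
    batch variance, then scaled by gamma and shifted by beta. The defaults
    (ones and zeros) match the initial affine parameters of nn.BatchNorm1d.
    """
    n_feat = x.shape[1]
    gamma = np.ones(n_feat) if gamma is None else gamma
    beta = np.zeros(n_feat) if beta is None else beta
    mean = x.mean(axis=0)
    var = x.var(axis=0)  # biased variance (ddof=0), used for normalization in training mode
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
feats = 5.0 * rng.normal(size=(256, 8)) + 3.0  # features with arbitrary scale and offset
out = batchnorm1d_train(feats)

print(np.round(out.mean(axis=0), 6))  # per-feature means are ~0
print(np.round(out.std(axis=0), 3))   # per-feature standard deviations are ~1
# The average squared feature norm ends up close to the feature dimension (8 here),
# independent of the input scale.
print(np.mean(np.sum(out ** 2, axis=1)))
```

Whatever the scale of the backbone's output, the features that reach the fixed classifier come out with roughly unit variance per dimension (until the affine parameters move away from their defaults). Whether this scale control matters for accuracy is the open point of the question.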

Implementation of dot regression loss is inconsistent with the equation in the paper.

In the paper, dot regression loss is defined as

$$ \mathcal{L}_{DR}(\mathbf{h}, \mathbf{W}^*) = \frac{1}{2\sqrt{E_W E_H}}\left(\mathbf{w}^{*T}_c\mathbf{h} - \sqrt{E_W E_H}\right)^2. $$

However, the implementation in the code is

elif criterion == 'reg_dot_loss':
    dot = torch.bmm(output.unsqueeze(1), target.unsqueeze(2)).view(-1) #+ classifier.module.bias[label].view(-1)
    with torch.no_grad():
        M_length = torch.sqrt(torch.sum(target ** 2, dim=1, keepdims=False))
    loss = (1/2) * torch.mean(((dot-(M_length * H_length)) ** 2) / H_length)

which is missing the $\sqrt{E_W}$ term in the denominator.
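A small NumPy sketch makes the size of the discrepancy explicit. It evaluates the paper's per-sample formula and the per-sample term of the quoted `reg_dot_loss` branch side by side. Two assumptions are made for illustration: $E_W = \|\mathbf{w}^*_c\|^2$ and $E_H = \|\mathbf{h}\|^2$, and `H_length` is the feature norm $\|\mathbf{h}\|$ (the snippet does not show how `H_length` is defined). The function names are hypothetical.

```python
import numpy as np

def dr_loss_paper(h, w_c):
    # Paper's L_DR for one sample, ASSUMING E_W = ||w_c||^2 and E_H = ||h||^2,
    # so that sqrt(E_W * E_H) = ||w_c|| * ||h||.
    scale = np.linalg.norm(w_c) * np.linalg.norm(h)
    return (w_c @ h - scale) ** 2 / (2.0 * scale)

def dr_loss_code(h, w_c):
    # Per-sample term of the quoted 'reg_dot_loss' branch, ASSUMING H_length = ||h||.
    # The denominator holds H_length only; M_length appears only inside the square.
    m_length = np.linalg.norm(w_c)  # M_length
    h_length = np.linalg.norm(h)    # H_length
    return 0.5 * (w_c @ h - m_length * h_length) ** 2 / h_length

rng = np.random.default_rng(0)
h = rng.normal(size=16)  # a feature vector
w = rng.normal(size=16)  # a class vector with arbitrary norm

# The two versions differ exactly by the factor ||w_c|| = sqrt(E_W).
print(dr_loss_code(h, w) / dr_loss_paper(h, w), np.linalg.norm(w))

# With a unit-norm class vector the missing factor is 1 and the two agree.
w_unit = w / np.linalg.norm(w)
print(np.isclose(dr_loss_code(h, w_unit), dr_loss_paper(h, w_unit)))  # True
```

Because the gap is a per-sample multiplicative factor $\sqrt{E_W}$, it amounts to a constant rescaling of the loss when every $\mathbf{w}^*_c$ has the same norm, and it only reweights samples when those norms differ.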

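Could you provide the training configuration file and data processing code for the CUB dataset?

Thank you very much for your work; it offers a genuinely novel idea, and I was able to reproduce the reported accuracy on the CIFAR datasets. On CUB, however, the model fails to converge when I follow the settings in the paper's appendix, and I can't tell which setting is wrong. Could you share the training configuration yaml file for CUB, along with the dataset loading and preprocessing code? Many thanks.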

ETF_classifier

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ETF_Classifier(nn.Module):
    def __init__(self, feat_in, num_classes, fix_bn=False, LWS=False, reg_ETF=False):
        super(ETF_Classifier, self).__init__()
        # Random (feat_in x num_classes) matrix with orthonormal columns
        P = self.generate_random_orthogonal_matrix(feat_in, num_classes)
        I = torch.eye(num_classes)
        one = torch.ones(num_classes, num_classes)
        # Simplex ETF: M = sqrt(K / (K - 1)) * P (I - 11^T / K)
        M = np.sqrt(num_classes / (num_classes-1)) * torch.matmul(P, I-((1/num_classes) * one))
        self.ori_M = M.cuda()
        # unused
        self.LWS = LWS
        self.reg_ETF = reg_ETF

        if LWS:
            self.learned_norm = nn.Parameter(torch.ones(1, num_classes))
            self.alpha = nn.Parameter(1e-3 * torch.randn(1, num_classes).cuda())
            # Note: learned_norm was registered as a Parameter above, so assigning a
            # plain tensor to it here raises a TypeError; this branch fails if LWS=True.
            self.learned_norm = (F.softmax(self.alpha, dim=-1) * num_classes)
        else:
            self.learned_norm = torch.ones(1, num_classes).cuda()

        self.BN_H = nn.BatchNorm1d(feat_in)
        if fix_bn:
            # Freeze the affine parameters of the feature BatchNorm
            self.BN_H.weight.requires_grad = False
            self.BN_H.bias.requires_grad = False

    def generate_random_orthogonal_matrix(self, feat_in, num_classes):
        a = np.random.random(size=(feat_in, num_classes))
        P, _ = np.linalg.qr(a)
        P = torch.tensor(P).float()
        assert torch.allclose(torch.matmul(P.T, P), torch.eye(num_classes), atol=1e-07), torch.max(torch.abs(torch.matmul(P.T, P) - torch.eye(num_classes)))
        return P

    def forward(self, x):
        # Only the BatchNorm is applied here; self.ori_M is not used in forward
        x = self.BN_H(x)
        return x

It looks like self.ori_M isn't used here either: forward only calls self.BN_H, and self.BN_H is just a normalization layer.
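Since the quoted `forward` only applies `BN_H`, the fixed matrix would have to be consumed elsewhere. The NumPy sketch below (NumPy rather than PyTorch, so it runs without a GPU) rebuilds $M$ exactly as the quoted `__init__` does and checks the simplex-ETF properties. It then shows one hypothetical way a fixed $M$ could turn normalized features into logits through the dot product $\mathbf{w}^{*T}_c\mathbf{h}$ that appears in the dot-regression loss. `simplex_etf` and `etf_logits` are names made up for this sketch; where the repository actually uses `ori_M` is not shown in the quoted code.

```python
import numpy as np

def simplex_etf(feat_in, num_classes, seed=0):
    """Rebuild the (feat_in, num_classes) matrix from the quoted __init__:
    M = sqrt(K / (K - 1)) * P @ (I - 11^T / K), where P has orthonormal columns."""
    rng = np.random.default_rng(seed)
    P, _ = np.linalg.qr(rng.random((feat_in, num_classes)))  # reduced QR: P is (feat_in, K)
    K = num_classes
    return np.sqrt(K / (K - 1)) * P @ (np.eye(K) - np.ones((K, K)) / K)

def etf_logits(features, M, eps=1e-5):
    """HYPOTHETICAL use of a fixed ETF: standardize each feature dimension over the
    batch (training-mode BatchNorm1d with default weight=1, bias=0), then score every
    class c by the dot product w_c^T h with the fixed column M[:, c]."""
    h = (features - features.mean(axis=0)) / np.sqrt(features.var(axis=0) + eps)
    return h @ M  # shape (batch, num_classes)

M = simplex_etf(feat_in=64, num_classes=10)
G = M.T @ M  # Gram matrix of the class vectors
print(np.allclose(np.diag(G), 1.0))                     # True: unit-norm columns
print(np.allclose(G[~np.eye(10, dtype=bool)], -1 / 9))  # True: pairwise cosine -1/(K-1)

rng = np.random.default_rng(1)
logits = etf_logits(rng.normal(size=(32, 64)), M)
print(logits.shape)  # (32, 10)
```

The Gram-matrix check shows that every column of this $M$ has unit length, so if `target` in the `reg_dot_loss` branch holds columns of this matrix (the quoted snippets do not show where `target` comes from), `M_length` would equal 1 for every sample.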

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.