
vat-pytorch's People

Contributors

lyakaap


vat-pytorch's Issues

Test Accuracy

Hi,
Thanks for sharing. Are the hyper-parameters set optimally? I get around 50% test accuracy on the CIFAR-10 dataset, which seems very low.

Thanks

Potential bug when calculating KL div

In lines 43 and 52, shouldn't it be F.softmax(pred.detach())? KLDivLoss() expects probabilities, not raw scores, as its second argument.

Also, in the forward pass pred_hat = model(X + r_adv), shouldn't there be a detach() on r_adv, as advised in the paper: "By the definition, r̃_vadv depends on θ. Our numerical experiments, however, indicates that ∇_θ r_vadv is quite volatile with respect to θ, and we could not make effective regularization when we used the numerical evaluation of ∇_θ r̃_vadv in ∇_θ LDS(x(n), θ). We have therefore followed the work of Goodfellow et al. (2015) and ignored the derivative of r̃_vadv with respect to θ."
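If that reading of KLDivLoss is right, the term could look like the sketch below. This is only an illustration of the suggestion in this issue, not code from the repo; the helper name kl_div_term and the dim handling are my own.

```python
import torch
import torch.nn.functional as F

def kl_div_term(pred, pred_hat):
    # F.kl_div expects log-probabilities as its first argument and
    # probabilities as the target; detaching the clean prediction stops
    # gradients flowing through it, so only pred_hat (computed from
    # x + r_adv) is differentiated.
    logp_hat = F.log_softmax(pred_hat, dim=1)
    p = F.softmax(pred.detach(), dim=1)
    return F.kl_div(logp_hat, p, reduction='batchmean')
```

The same detach argument would apply to r_adv before the perturbed forward pass, per the quoted passage from the paper.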

Misplaced parameter in README example

In the README example, the model is passed to VATLoss.__init__(), but it should be passed to VATLoss.forward(). Instead of

VATLoss(model, xi=10.0, eps=1.0, ip=1)
vat_loss(data)

it should be

VATLoss(xi=10.0, eps=1.0, ip=1)
vat_loss(model, data)

thanks

Some code issues

  1. In VAT-pytorch/vat.py, lines 7-16: running_mean and running_var in the BatchNorm layers keep updating if you only set track_running_stats to False. Setting those layers to eval mode would be better.

  2. In VAT-pytorch/train_CIFAR10.py, lines 53-58: in the original paper, both labeled and unlabeled samples are used to compute the regularization loss, and they use different batch sizes (64 and 128, respectively).
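One way to realize point 1 is a context manager that switches the BatchNorm layers to eval mode for the duration of the VAT computation and then restores their previous state. This is a hypothetical sketch (freeze_bn_stats is my own name, not code from the repo):

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def freeze_bn_stats(model):
    # Collect the BatchNorm layers and remember their training flags.
    bn = [m for m in model.modules()
          if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))]
    modes = [m.training for m in bn]
    for m in bn:
        m.eval()  # eval mode uses, but does not update, the running stats
    try:
        yield model
    finally:
        for m, mode in zip(bn, modes):
            m.train(mode)
```

The rest of the model stays in training mode, so only the batch-statistics tracking is affected.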

Single loop training method

Is there a reason for using a single training loop rather than the usual

   for epoch in range(n_epochs):
       for it in range(n_iters):
           train.....

I understand that it may not be trivial to determine the number of iterations when there are two data loaders with different numbers of samples, but any explanation would be appreciated.
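For what it's worth, one common pattern for keeping the epoch/iteration structure with two loaders of different lengths is to cycle the shorter (labeled) loader while stepping through the longer (unlabeled) one. A sketch under my own names, not the repo's code:

```python
import itertools

def epoch_pairs(labeled_loader, unlabeled_loader):
    # Cycle the (typically shorter) labeled loader so that every
    # unlabeled batch in the epoch is paired with some labeled batch.
    labeled_iter = itertools.cycle(labeled_loader)
    for x_ul in unlabeled_loader:
        yield next(labeled_iter), x_ul
```

With this, `for epoch in range(n_epochs): for x_l, x_ul in epoch_pairs(...)` recovers the nested-loop structure. One caveat: itertools.cycle caches the batches it has seen, so with a shuffling DataLoader you may prefer re-creating the labeled iterator each epoch instead.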

random unit tensor strictly positive

Hello,

When you "prepare random unit tensor" in the VAT code, you use this :

d = torch.rand(x.shape).to( torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
d = _l2_normalize(d)

This means that d is strictly non-negative. Is that intended? Sometimes we need a partially negative perturbation.

Why didn't you use something like the following?

d = torch.rand(x.shape) - 0.5
d = d.to('cuda' if torch.cuda.is_available() else 'cpu')
d = _l2_normalize(d)
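A zero-mean Gaussian draw would also do the job: torch.randn is symmetric around zero, so after per-sample L2 normalization the direction is uniform over the unit sphere and its components can take either sign, unlike torch.rand, which samples from [0, 1). A sketch (random_unit_direction is a hypothetical helper, not repo code):

```python
import torch

def random_unit_direction(x):
    # One unit-norm direction per sample in the batch.
    d = torch.randn_like(x)            # zero-mean Gaussian: any sign
    flat = d.view(d.size(0), -1)
    flat = flat / (flat.norm(dim=1, keepdim=True) + 1e-8)
    return flat.view_as(x)
```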

Regression & SSDKL variant

First of all, thank you for this great work. In the repository below, they use VAT for regression and also propose another method, SSDKL, which they claim outperforms VAT in that setting.

https://github.com/ermongroup/ssdkl

Unfortunately, it is not in PyTorch. I would be glad if you could include a regression example for VAT in this repository, and also the SSDKL technique if possible. Thank you again!

What's the point of _disable_tracking_bn_stats()?

I don't understand what _disable_tracking_bn_stats() is trying to do. I don't think the network itself has a track_running_stats attribute for the condition to be met; if anything, you would have to descend to the BatchNorm layers.

I think you are trying to fix running_mean and running_var while computing the VAT loss, but I don't think changing track_running_stats achieves that either.
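To illustrate the point about descending to the submodules: track_running_stats lives on each BatchNorm layer, not on the top-level module, so any such switch has to visit the children, e.g. via model.apply. A hedged sketch (set_bn_tracking is my own name; whether flipping this flag actually freezes the statistics is exactly the question raised above):

```python
import torch.nn as nn

def set_bn_tracking(model, enabled):
    # model.apply visits every submodule, including nested BatchNorms.
    def _set(m):
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.track_running_stats = enabled
    model.apply(_set)
```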

Hyperparameters xi

In the VAT paper the hyperparameter xi is set to 1e-6 (only a small value is consistent with the Taylor expansion), so why is it 10 in this repo?
