pytorch_dann's Introduction

Pytorch_DANN

This is an implementation of Domain-Adversarial Training of Neural Networks
in PyTorch. The paper introduced a simple and effective method for performing
domain adaptation with SGD using a GRL (Gradient Reversal Layer). According
to the paper, a domain classifier is used to decrease the H-divergence between
the source domain distribution and the target domain distribution. For the
TensorFlow version, see tf-dann.
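
As a rough illustration of the GRL idea described above, here is a minimal sketch of a gradient reversal layer written with torch.autograd.Function. It is not taken from this repository's models.py; the helper name grad_reverse and the constant argument mirror the call that appears in the issue tracebacks, but the body is an assumption.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; scales the gradient by -constant in the backward pass."""

    @staticmethod
    def forward(ctx, x, constant):
        ctx.constant = constant  # plain Python float, stored for the backward pass
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; `constant` is not a tensor, so it gets None.
        return -ctx.constant * grad_output, None


def grad_reverse(x, constant=1.0):
    """Convenience wrapper (assumed name) around GradReverse.apply."""
    return GradReverse.apply(x, constant)


x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, constant=0.5).sum()
y.backward()
# The forward value is unchanged, but x.grad holds -0.5 per element instead of +1.
```

Placed between the feature extractor and the domain classifier, such a layer lets a single SGD step train the domain classifier to separate the domains while pushing the features toward making them indistinguishable.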

Requirements

python3.6.2
pip install -r requirements.txt

Data

In this work, the MNIST and MNIST_M datasets are used in the experiments. The
MNIST dataset can be downloaded with torchvision.datasets. The MNIST_M dataset
can be downloaded from Yaroslav Ganin's homepage. Extract the file to your data
directory and run preprocess.py so that the directory can be used with
torchvision.datasets.ImageFolder:

python preprocess.py
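
The contents of preprocess.py are not shown here. ImageFolder expects a root/<class_name>/<image> layout, so a preprocessing step of this kind might look roughly like the sketch below. It assumes a flat image directory and a label file with one "<filename> <label>" pair per line; the function name, the file names, and that label format are all assumptions for illustration.

```python
import shutil
import tempfile
from pathlib import Path


def arrange_for_imagefolder(image_dir, label_file, out_dir):
    """Copy each image into out_dir/<label>/ so ImageFolder can infer classes from folder names."""
    image_dir, out_dir = Path(image_dir), Path(out_dir)
    count = 0
    for line in Path(label_file).read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        filename, label = line.split()
        class_dir = out_dir / label
        class_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_dir / filename, class_dir / filename)
        count += 1
    return count


# Tiny self-contained demo with fake files (hypothetical names).
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    raw = tmp / "raw"
    raw.mkdir()
    for name in ("a.png", "b.png"):
        (raw / name).write_bytes(b"fake image bytes")
    labels = tmp / "labels.txt"
    labels.write_text("a.png 3\nb.png 7\n")
    copied = arrange_for_imagefolder(raw, labels, tmp / "mnist_m_train")
    layout = sorted(p.relative_to(tmp / "mnist_m_train").as_posix()
                    for p in (tmp / "mnist_m_train").rglob("*.png"))
```

After a step like this, each digit class has its own subdirectory and the output directory can be passed as the root argument of ImageFolder.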

Experiments

You can run main.py to reproduce the MNIST experiments from the paper with a
similar model and the same parameters. The paper's results and this work's
results are as follows:

Method        Target Acc (paper)    Target Acc (this work)
Source Only   0.5225                0.5189
DANN          0.7666                0.7600

An experiment on SVHN->MNIST has been added to this project, but some bugs
remain unfixed: the source and target accuracies are not both good at the same
time.

An experiment on SynDig->SVHN has also been added.

pytorch_dann's Issues

Evaluation does not suppress dropout, which will affect performance

Hi there, thanks for the great repo! I think the dropout layer should be created in __init__ of Class_classifier so that model.eval() disables it during evaluation. Your implementation calls the functional form inline, which defaults to training=True and therefore stays active at evaluation time:

logits = self.fc2(F.dropout(logits))

After adding self.dropout = nn.Dropout() in __init__ and replacing this line with

logits = self.dropout(logits)
logits = self.fc2(logits)

I was able to obtain a 2~3% performance gain on the tasks using the same model checkpoint and inputs.
Could you help me verify if my understanding is correct?
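
The behaviour described in this issue can be checked in isolation. The sketch below uses two hypothetical toy modules (not the repository's Class_classifier) and compares F.dropout called with its default arguments against an nn.Dropout submodule, both after calling eval().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalDropoutNet(nn.Module):
    def forward(self, x):
        # F.dropout defaults to training=True, so eval() does not switch it off.
        return F.dropout(x)


class ModuleDropoutNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout()

    def forward(self, x):
        # nn.Dropout follows the module's training flag and is the identity in eval mode.
        return self.dropout(x)


torch.manual_seed(0)
x = torch.ones(1000)
functional_out = FunctionalDropoutNet().eval()(x)  # still zeroes roughly half the entries
module_out = ModuleDropoutNet().eval()(x)          # returned unchanged
```

An equivalent fix that keeps the functional call is to pass the flag explicitly: F.dropout(x, training=self.training).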

There is something wrong

When I run main.py, I get this error: RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

Version issue

I tried to run the code on py2.7 but it keeps giving me an error -
../models/models.py", line 68, in forward
input = GradReverse.grad_reverse(input, constant)
TypeError: unbound method grad_reverse() must be called with GradReverse instance as first argument (got Variable instance instead)
I tried some more combinations of versions (py-3.5, torch 0.4.1, etc.), and all of them gave one error or another.
I suspect it is due to a version issue with some package. It would be very helpful if you could upload an environment.yml file or list the version details in a req.txt file.
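
The TypeError quoted above is what Python 2 raises when a function defined in a class body without @staticmethod is called on the class itself; Python 3 accepts the same call. A minimal sketch of the decorator form, which works on both, is shown below. The class is a hypothetical stand-in, not the repository's GradReverse.

```python
class GradReverseHelper:
    """Hypothetical stand-in for a class whose helper is called on the class itself."""

    @staticmethod
    def grad_reverse(x, constant):
        # Placeholder body; a real layer would wrap an autograd function here.
        return x, constant


# Called on the class, with no instance, as in the traceback above.
result = GradReverseHelper.grad_reverse("features", 0.1)
```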

MNIST Dataloader error

There is a minor error in the train and test dataloaders for the MNIST dataset. The normalization transform receives a mean and standard deviation for three channels, whereas MNIST images have only a single channel. This can be fixed by changing the values passed to the transform, for example:

if dataset == 'MNIST':
    transform = transforms.Compose([
        transforms.ToTensor(),
        # One-element tuples: a bare float can raise IndexError in some torchvision versions.
        transforms.Normalize(mean=(params.dataset_mean[0],), std=(params.dataset_std[0],)),
    ])

This way only the single channel is normalized, as it should be.
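
Another way to resolve the shape mismatch, if the source and target images should share one network input shape, is to repeat the grayscale channel three times before normalizing. The sketch below uses plain torch; the function names and the 0.5 statistics are placeholders, not the values in this repository's params.

```python
import torch


def to_three_channels(img):
    """Repeat a (1, H, W) grayscale tensor along the channel axis; leave (3, H, W) as is."""
    return img.expand(3, -1, -1) if img.shape[0] == 1 else img


def normalize(img, mean, std):
    """Out-of-place per-channel normalization; mean and std need one entry per channel."""
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return (img - mean) / std


gray = torch.rand(1, 28, 28)     # shape of an MNIST image after ToTensor()
rgb = to_three_channels(gray)    # shape (3, 28, 28)
out = normalize(rgb, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # placeholder statistics
```

In a torchvision pipeline, a function like to_three_channels could be wrapped in transforms.Lambda and placed between ToTensor() and Normalize(...).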

Result has huge difference

Hi,
My results differ greatly from yours. At epoch 99, I get:

Source Accuracy: 9839/10000 (98.0000%)
Target Accuracy: 5958/9001 (66.0000%)
Domain Accuracy: 10646/19001 (56.0000%)

May I ask how many epochs you ran to get your result?
Thank you in advance.

Loss function

Hi,

I'm not sure if this is an issue, or that's my misunderstanding. You calculated loss as:

loss = class_loss + params.theta * domain_loss

shouldn't it be:

loss = class_loss - params.theta * domain_loss

according to the formula (9) of the paper?
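
One way to reason about the sign is to write out the gradients, assuming the domain classifier receives the features through a gradient reversal layer R (as the GradReverse.grad_reverse(input, constant) call in the Version issue traceback suggests). The symbols below are introduced for illustration: f is the feature extractor output, c is the constant passed to the GRL, and \theta stands for params.theta.

```latex
% Gradient reversal layer: identity forward, sign-flipped and scaled gradient backward.
R(f) = f, \qquad \frac{\partial R}{\partial f} = -c\, I

% Loss as written in the code, with the domain branch fed through R:
L = L_y\big(G_y(f)\big) + \theta\, L_d\big(G_d(R(f))\big), \qquad f = G_f(x;\theta_f)

% Domain classifier parameters: ordinary descent on the domain loss.
\frac{\partial L}{\partial \theta_d} = \theta\, \frac{\partial L_d}{\partial \theta_d}

% Feature extractor parameters: the GRL flips the sign of the domain term.
\frac{\partial L}{\partial \theta_f}
  = \frac{\partial L_y}{\partial \theta_f} - \theta\, c\, \frac{\partial L_d}{\partial \theta_f}
```

Under that assumption, the feature extractor receives the same gradient it would get from the minus-sign objective L_y - \theta c L_d, while the domain classifier still minimizes L_d. The plus sign in the code and the minus sign in the objective would then describe the same update, with the GRL supplying the sign flip.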

Error when starting training

I've been struggling with this error for many days. Maybe it's really simple, but I can't figure out the cause yet. The error is the following:

Traceback (most recent call last):
  File "main.py", line 205, in <module>
    main(parse_arguments(sys.argv[1:]))
  File "main.py", line 169, in main
    src_train_dataloader, tgt_train_dataloader, optimizer, epoch)
  File "/Volumes/Elements/DANN/pytorch_DANN-master/train/train.py", line 36, in train
    for batch_idx, (sdata, tdata) in enumerate(zip(source_dataloader, target_dataloader)):
  File "/Volumes/Elements/Apps/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 615, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/Volumes/Elements/Apps/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 615, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/Volumes/Elements/Apps/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "/Volumes/Elements/Apps/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/Volumes/Elements/Apps/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 163, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/Volumes/Elements/Apps/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 208, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
IndexError: too many indices for tensor of dimension 0

I'd really appreciate your help.

Can this be used in Regression Problem?

Hi @tengerye @CuthbertCai! Thank you for sharing the code.

I was wondering whether this domain adaptation technique can be used for a regression problem, say one where the loss function for the main task is MAE or MSE. I am a bit confused about how to backpropagate the loss in that case. In the classification problem, the total loss is (according to the code) the sum of the class loss and the domain loss, and since both are negative log-likelihood losses there is no issue in adding them. But what should we do if the loss for the main task and the loss for the domain classification task are of different types? Your insight would be really helpful! Thanks!
