domainbiasmitigation's People

Contributors: zywangcode

domainbiasmitigation's Issues

Unstable Domain Accuracy Performance

Hi,
When training with the celeba_uniconf_adv config, the domain accuracy toward the end of training (~48-50 epochs) reaches 98% and then routinely oscillates back to 50%.
This seems erratic. Is this expected behavior? Could you please help resolve it?

Training and Test logs for reference:

Training epoch 48: [4351|5087], class loss:0.1915399134159088, domain loss: 0.6868267059326172, domain accuracy: 51.35241898414158
Training epoch 48: [4401|5087], class loss:0.16185306012630463, domain loss: 0.6825014352798462, domain accuracy: 51.327113156100886
Training epoch 48: [4451|5087], class loss:0.2070169895887375, domain loss: 0.6903568506240845, domain accuracy: 51.32273646371602
Training epoch 48: [4501|5087], class loss:0.1745804101228714, domain loss: 0.6876736283302307, domain accuracy: 51.31498555876472
Training epoch 48: [4551|5087], class loss:0.18190763890743256, domain loss: 0.6901058554649353, domain accuracy: 51.29435838277302
Training epoch 48: [4601|5087], class loss:0.20187050104141235, domain loss: 0.7073038220405579, domain accuracy: 51.30270593349272
Training epoch 48: [4651|5087], class loss:0.15669050812721252, domain loss: 0.7052844762802124, domain accuracy: 51.27929477531713
Training epoch 48: [4701|5087], class loss:0.16493597626686096, domain loss: 0.6976796388626099, domain accuracy: 51.28563071686875
Training epoch 48: [4751|5087], class loss:0.18540211021900177, domain loss: 0.6841596364974976, domain accuracy: 51.301699642180594
Training epoch 48: [4801|5087], class loss:0.16415195167064667, domain loss: 0.7053844332695007, domain accuracy: 51.285539470943554
Training epoch 48: [4851|5087], class loss:0.1703886091709137, domain loss: 0.7098171710968018, domain accuracy: 51.283884766027626
Training epoch 48: [4901|5087], class loss:0.17944636940956116, domain loss: 0.7142029404640198, domain accuracy: 51.2624974495001
Training epoch 48: [4951|5087], class loss:0.1634552776813507, domain loss: 0.6973509192466736, domain accuracy: 51.27120783680065
Training epoch 48: [5001|5087], class loss:0.18260404467582703, domain loss: 0.6861003041267395, domain accuracy: 51.269121175764845
Training epoch 48: [5051|5087], class loss:0.16221176087856293, domain loss: 0.6899145841598511, domain accuracy: 51.270169273411206
Finish training epoch 49, dev class loss: 0.1895156781550575, dev doamin loss: 0.6892390505511212, dev mAP: 0.7547963168677299,domain_accuracy: 55.12659183570745, time used: 0:27:10.755864
Training epoch 49: [1|5087], class loss:0.15959756076335907, domain loss: 0.7002204060554504, domain accuracy: 43.75
Training epoch 49: [51|5087], class loss:0.190892294049263, domain loss: 0.14717990159988403, domain accuracy: 83.08823529411765
Training epoch 49: [101|5087], class loss:0.17048077285289764, domain loss: 0.15136627852916718, domain accuracy: 88.49009900990099
Training epoch 49: [151|5087], class loss:0.18803445994853973, domain loss: 0.12117624282836914, domain accuracy: 90.93543046357615
Training epoch 49: [201|5087], class loss:0.1920115053653717, domain loss: 0.022530484944581985, domain accuracy: 92.24191542288557
Training epoch 49: [251|5087], class loss:0.1748587042093277, domain loss: 0.04192587360739708, domain accuracy: 92.94073705179282
Training epoch 49: [301|5087], class loss:0.16945448517799377, domain loss: 0.07877128571271896, domain accuracy: 93.51121262458472
Training epoch 49: [351|5087], class loss:0.1647031009197235, domain loss: 0.09114453941583633, domain accuracy: 93.90135327635328
Training epoch 49: [401|5087], class loss:0.16736702620983124, domain loss: 0.077174112200737, domain accuracy: 94.1708229426434
Training epoch 49: [451|5087], class loss:0.16356484591960907, domain loss: 0.031743574887514114, domain accuracy: 94.49140798226163
Training epoch 49: [501|5087], class loss:0.16792656481266022, domain loss: 0.049991361796855927, domain accuracy: 94.66067864271457
Training epoch 49: [551|5087], class loss:0.17482462525367737, domain loss: 0.11094588786363602, domain accuracy: 94.90131578947368
Training epoch 49: [601|5087], class loss:0.16284793615341187, domain loss: 0.014430172741413116, domain accuracy: 95.06031613976705
Training epoch 49: [651|5087], class loss:0.182270348072052, domain loss: 0.0598013773560524, domain accuracy: 95.21889400921658
Training epoch 49: [701|5087], class loss:0.17393416166305542, domain loss: 0.015329426154494286, domain accuracy: 95.3370185449358
Training epoch 49: [751|5087], class loss:0.17821387946605682, domain loss: 0.040832579135894775, domain accuracy: 95.48102529960053
Training epoch 49: [801|5087], class loss:0.1586090475320816, domain loss: 0.12714128196239471, domain accuracy: 95.61875780274657
Training epoch 49: [851|5087], class loss:0.16232921183109283, domain loss: 0.030300777405500412, domain accuracy: 95.71092831962397
Training epoch 49: [901|5087], class loss:0.17451979219913483, domain loss: 0.04219125211238861, domain accuracy: 95.83795782463929
Training epoch 49: [951|5087], class loss:0.17216108739376068, domain loss: 0.159927636384964, domain accuracy: 95.91219768664564
Training epoch 49: [1001|5087], class loss:0.16925860941410065, domain loss: 0.17401568591594696, domain accuracy: 96.00399600399601
Training epoch 49: [1051|5087], class loss:0.17023028433322906, domain loss: 0.1129363551735878, domain accuracy: 96.08111322549952
Training epoch 49: [1101|5087], class loss:0.17867332696914673, domain loss: 0.007533952593803406, domain accuracy: 96.15974114441417
Training epoch 49: [1151|5087], class loss:0.167255237698555, domain loss: 0.058074675500392914, domain accuracy: 96.21796264118159
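
For context on why this may not be a bug: a domain loss of ~0.69 at ~51% accuracy (epoch 48) is exactly the fixed point a uniform-confusion objective targets, since ln 2 ~= 0.693 is the minimum of the cross-entropy against a uniform two-domain posterior. A minimal sketch of such a loss (a hypothetical helper, not necessarily the repository's implementation):

```python
import math

import torch
import torch.nn.functional as F

def uniform_confusion_loss(domain_logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between the predicted domain posterior and the
    uniform distribution over the K domains.  Its minimum value is
    log(K), reached exactly when the classifier is maximally confused."""
    log_probs = F.log_softmax(domain_logits, dim=1)
    return -log_probs.mean(dim=1).mean()

# A perfectly confused 2-domain classifier sits at log(2) ~= 0.6931,
# close to the domain loss reported throughout the epoch-48 log.
confused = uniform_confusion_loss(torch.zeros(8, 2))
```

On this reading, the jump to ~96% domain accuracy in epoch 49 would mean the domain classifier temporarily escaped the confused equilibrium, which adversarial training is prone to when the two players' updates are imbalanced.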

Saved model files

Hi,
I found this paper very interesting and would like to work in this setting. Could you provide the saved models for the experiments? Could you also provide the code for calculating the Bias Amplification?

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:

I am trying to run the celeba_gradproj_adv experiment, but for some reason the gradient computation is not working properly. The classifier network works, but when the output of the classifier is fed to the domain network, which is just a single Linear layer, the model throws an error. I have set torch.autograd.set_detect_anomaly(True) as well but can't get this code to work. Any help is appreciated.

The error appears in the _train function of the models/celeba_gradproj_adv.py script:

# Update the main network
if self.epoch % self.training_ratio == 0:
    grad_from_class = torch.autograd.grad(class_loss, self.class_network.parameters(),
                                          retain_graph=True, allow_unused=True)
    grad_from_domain = torch.autograd.grad(domain_loss, self.class_network.parameters(),
                                           retain_graph=True, allow_unused=True)
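
For context, per-parameter gradients like those above are typically followed by a projection that removes the domain component from the class gradient. A minimal sketch of that step for a single parameter tensor (a hypothetical project_out helper, not the repository's code):

```python
import torch

def project_out(grad_class: torch.Tensor, grad_domain: torch.Tensor,
                eps: float = 1e-12) -> torch.Tensor:
    """Remove from grad_class its component along grad_domain:
    g_c' = g_c - (g_c . g_d / ||g_d||^2) g_d,
    leaving g_c' orthogonal to the domain gradient."""
    g_c, g_d = grad_class.flatten(), grad_domain.flatten()
    coeff = torch.dot(g_c, g_d) / (torch.dot(g_d, g_d) + eps)
    return (g_c - coeff * g_d).view_as(grad_class)

# e.g. projecting [1, 1] against [1, 0] leaves only the orthogonal part
g = project_out(torch.tensor([1.0, 1.0]), torch.tensor([1.0, 0.0]))
```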

Setting torch.autograd.set_detect_anomaly(True) points to this earlier block in the _train function:

class_outputs, _ = self.class_network(images)
domain_outputs = self.domain_network(class_outputs)

Here is a snippet of the full error message:

C:\Users\divya\miniconda3\envs\pytorch\lib\site-packages\torch\autograd\anomaly_mode.py:70: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  warnings.warn('Anomaly Detection has been enabled. '
Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "main.py", line 15, in <module>
    main(model, opt)
  File "main.py", line 9, in main
    model.train()
  File "D:\thesis\celeba_classifiers\models\celeba_gradproj_adv.py", line 178, in train
    self._train(self.train_loader)
  File "D:\thesis\celeba_classifiers\models\celeba_gradproj_adv.py", line 80, in _train
    domain_outputs = self.domain_network(class_outputs.clone())
  File "C:\Users\divya\miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\divya\miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\divya\miniconda3\envs\pytorch\lib\site-packages\torch\nn\functional.py", line 1610, in linear
    ret = torch.addmm(bias, input, weight.t())
 (print_stack at ..\torch\csrc\autograd\python_anomaly_mode.cpp:60)
Traceback (most recent call last):
  File "main.py", line 15, in <module>
    main(model, opt)
  File "main.py", line 9, in main
    model.train()
  File "D:\thesis\celeba_classifiers\models\celeba_gradproj_adv.py", line 178, in train
    self._train(self.train_loader)
  File "D:\thesis\celeba_classifiers\models\celeba_gradproj_adv.py", line 101, in _train
    retain_graph=True, allow_unused=True)
  File "C:\Users\divya\miniconda3\envs\pytorch\lib\site-packages\torch\autograd\__init__.py", line 158, in grad
    inputs, allow_unused)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [39, 2]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
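
This class of error typically occurs when a parameter used in the retained graph is updated in place (as optimizer.step() does) between the forward pass and a later backward through that graph; the [39, 2] tensor is consistent with the transposed weight of the single Linear domain layer. A minimal reproduction of the failure mode (an assumption about the cause, not the repository's code):

```python
import torch

# A second backward through a retained graph fails if a saved tensor
# (here the Linear weight) was modified in place in the meantime.
lin = torch.nn.Linear(3, 2)
x = torch.randn(4, 3, requires_grad=True)
loss = lin(x).sum()

torch.autograd.grad(loss, [x], retain_graph=True)  # first backward: works
with torch.no_grad():
    lin.weight.add_(1.0)  # in-place update, as optimizer.step() would do

raised = False
try:
    torch.autograd.grad(loss, [x])  # needs the (now modified) weight
except RuntimeError as err:
    raised = "inplace operation" in str(err)
```

If that is what is happening here, the usual remedy is to compute every torch.autograd.grad call before any optimizer step, or to redo the forward pass after the update.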

IndexError: while accessing the outputs in RBA_celeba.ipynb file

Hi all,
I am currently studying the paper and trying to run this code.
I've trained the 'celeb_baseline', 'celeba_uniconf_adv' & 'celeba_domain_independent' models, but while running the RBA_celeba.ipynb file I hit an indexing error on the test and dev outputs ("IndexError: index 40 is out of bounds for axis 1 with size 39") in the cells below. Can you please help me fix this issue?

dev_outputs = dev['output'][:, subclass_idx + [item+39 for item in subclass_idx]]
test_outputs = test['output'][:, subclass_idx + [item+39 for item in subclass_idx]]

IndexError Traceback (most recent call last)
in
----> 1 dev_outputs = dev['output'][:, subclass_idx + [item+39 for item in subclass_idx]]
2 test_outputs = test['output'][:, subclass_idx + [item+39 for item in subclass_idx]]

IndexError: index 40 is out of bounds for axis 1 with size 39

Note:
Shapes of the dev and test results:
dev['output'].shape is (19867, 39)
test['output'].shape is (19962, 39), yet the script accesses index 40. Can you please help me understand?
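
For reference, that indexing only makes sense on a 78-column output (39 target scores per domain, as a two-headed domain-independent model would produce); a 39-column output cannot supply column 40. A small NumPy illustration (with a hypothetical subclass_idx):

```python
import numpy as np

subclass_idx = [1, 2]  # hypothetical attribute indices
cols = subclass_idx + [i + 39 for i in subclass_idx]  # [1, 2, 40, 41]

two_head = np.zeros((5, 78))     # 39 columns per domain head
single_head = np.zeros((5, 39))  # single-head output, as saved here

ok = two_head[:, cols].shape     # works on a 78-column output
raised = False
try:
    single_head[:, cols]         # column 40 does not exist
except IndexError:
    raised = True
```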
