dataset-distillation's People

Contributors

carmocca, ilia10000, ssnl

dataset-distillation's Issues

A weird thing in the backward function

Hello, I have cloned this repo and am trying to understand the code.

However, I found some puzzling things in the Trainer class of train_distilled_image.py.

The Trainer class has a method named backward:

def backward(self, model, rdata, rlabel, steps, saved_for_backward):
    l, params, gws = saved_for_backward
    # ....

The params and gws come from the forward function, but they have different lengths!

I inserted print statements into the forward function like this:

def forward(self, model, rdata, rlabel, steps):
    # .... code ....
    print(f"params's length is {len(params)}")
    print(f"gws's length is {len(gws)}")
    return ll, (ll, params, gws)

and run this command:

python main.py --mode distill_basic --dataset Cifar10 --arch AlexCifarNet  --distill_lr 0.001 --train_nets_type known_init --n_nets 1  --test_nets_type same_as_train

You will see the log:

params's length is 31
gws's length is 30

In the backward function, there is a zip call:

for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):

zip(steps, params, gws) truncates to the shortest input, so it ignores the final element of params.

Question 1: Is that a mistake? Does that final element of params affect the training?
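
For context, here is a self-contained toy of how forward plausibly builds the two lists (my own sketch with assumed structure, not the repo's exact code); it shows why params ends up exactly one element longer than gws:

    import torch

    # params records the weights BEFORE each GD step plus the final
    # post-step weights; gws records one gradient per step.
    torch.manual_seed(0)
    w = torch.randn(4, requires_grad=True)  # stand-in for the flat weights
    steps = [(torch.randn(8, 4), torch.randn(8), 0.1) for _ in range(3)]

    params, gws = [w], []
    for data, label, lr in steps:
        loss = ((data @ params[-1] - label) ** 2).mean()  # toy inner loss
        gw, = torch.autograd.grad(loss, (params[-1],), create_graph=True)
        gws.append(gw)
        params.append(params[-1] - lr * gw)  # post-GD weights

    print(len(params), len(gws))  # 4 3 -- one extra entry in params
    # Presumably params[-1] is still used outside the zip loop, to compute
    # the final evaluation loss and its gradient dw.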

In the backward function:

for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
    # hvp_in are the tensors we need gradients w.r.t. final L:
    #   lr (if learning)
    #   data
    #   ws (PRE-GD) (needed for next step)
    #
    # source of gradients can be from:
    #   gw, the gradient in this step, whose gradients come from:
    #     the POST-GD updated ws
    hvp_in = [w]
    if not state.freeze_data:
        hvp_in.append(data)
    hvp_in.append(lr)
    if not state.static_labels:
        hvp_in.append(label)
    # dw is d(final L) w.r.t. the post-GD weights, computed earlier in
    # backward and updated each iteration
    dgw = dw.neg()  # gw is already weighted by lr, so simple negation
    hvp_grad = torch.autograd.grad(
        outputs=(gw,),
        inputs=hvp_in,
        grad_outputs=(dgw,),
        retain_graph=True
    )

In the first iteration, w is params[-2], and hvp_grad contains the gradient of gw with respect to params[-2]. However, dgw is the gradient of the loss with respect to params[-1].

I cannot fully understand the meaning of hvp_grad. (Is it Newton's method?)

Question 2: The logic of hvp_grad is hard to understand. Could you please explain those gradients in detail?
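
For reference, a self-contained toy of the mechanism hvp_grad appears to rely on (my own sketch, not the repo's code): differentiating a gradient g against an incoming vector v yields a Hessian-vector product H @ v. That is how reverse mode chains through a gradient-descent step w_new = w - lr * g(w) without ever forming H; it is backpropagation through optimization rather than Newton's method:

    import torch

    torch.manual_seed(0)
    w = torch.randn(5, requires_grad=True)
    l = (w ** 4).sum()                                  # toy inner loss
    g, = torch.autograd.grad(l, w, create_graph=True)   # g = dl/dw, graph kept
    v = torch.randn(5)                                  # plays the role of dgw
    hv, = torch.autograd.grad(g, w, grad_outputs=v)     # hv = H @ v

    # Check against the explicit Hessian: here H = diag(12 * w**2).
    print(torch.allclose(hv, 12 * w.detach() ** 2 * v))  # True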

Where is TextConvNet3?

In the text distillation section, you used the TextConvNet3 arch, but I can only find TextConvNet1 in your code.
May I ask where TextConvNet3 is?

One of the differentiated Tensors appears to not have been used in the graph

Hi,
When I try to do soft-label distillation on my own model and dataset with this command:

python main.py --dataset my_dataset --arch my_arch --epochs 1 --batch_size 4 --test_batch_size 1 --lr 0.01 --decay_epochs 1 --train_nets_type loaded --test_nets_type same_as_train --device_id 0 --num_workers 3 --mode distill_basic --distill_epochs 1 --distill_steps 1 --distill_lr 0.01 --static_labels 0 --num_distill_classes 2

This error appears:

[screenshot of the traceback, ending in the error quoted in the title]

Note that when I run distillation with fixed (hard) labels, it works.

After some analysis, the problem seems to come from the calculation of hvp_grad in train_distilled_image.py. Indeed, this error is explained in this forum thread: https://discuss.pytorch.org/t/example-for-one-of-the-differentiated-tensors-appears-to-not-have-been-used-in-the-graph/58396 and seems to appear when a tensor passed as an input is not the tensor actually used to compute the output.

For example, the following code gives the same error, because -x creates a new tensor that is not part of output's graph:

    import torch

    x = torch.rand(10, requires_grad=True)
    output = (2 * x).sum()
    # -x is a fresh tensor; output was computed from x, not from -x, so
    # autograd cannot find -x in output's graph and raises the error above.
    grad_x, = torch.autograd.grad(output, -x)
    print(grad_x)
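
For contrast, a minimal working variant (differentiating with respect to x itself, which is in output's graph):

    import torch

    x = torch.rand(10, requires_grad=True)
    output = (2 * x).sum()
    grad_x, = torch.autograd.grad(output, x)  # works: output was built from x
    print(grad_x)  # ten 2.0s, since d/dx of (2 * x).sum() is 2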

I hope you can help me solve this problem.

Thank you in advance

Questions on loss calculation

I have several questions about the loss calculation.
First, I noticed that the loss for a two-class model is computed differently. Since my model has two classes, errors appeared (including memory limits), so I removed this condition and compute the loss only with cross entropy, F.cross_entropy(output, target), because binary classification can be treated as multi-class classification:

[screenshot of the two-class loss branch in the code]

The same change was made everywhere else the loss is calculated; a sketch of the idea follows below.
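
A minimal sketch of that change (the helper name and the original binary branch are my assumptions, not the repo's verified code):

    import torch
    import torch.nn.functional as F

    def task_loss(output, target):
        # Hypothetical unified loss: instead of a special branch for
        # num_classes == 2 (e.g. a binary loss on a single logit), treat
        # two classes as ordinary multi-class classification.
        return F.cross_entropy(output, target)

    logits = torch.randn(4, 2)        # batch of 4, two classes
    labels = torch.tensor([0, 1, 1, 0])
    print(task_loss(logits, labels))  # scalar loss tensor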

I also have doubts about the loss values computed in the train() function of train_distilled_image.py: the loss increases at each step instead of decreasing...

[screenshots of the logged loss values increasing across steps]

data_path

Hello, where should the glove.6B.zip file that I downloaded be placed locally (what data path does the code expect)?

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Thank you for this great work. The idea of adding label distillation to dataset distillation is very interesting.

When I try to run your code with the command
python main.py --mode distill_basic --dataset Cifar10 --arch AlexCifarNet --distill_lr 0.001
There is an error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

The detailed error message is as follows:

Traceback (most recent call last):
  File "/home/user/soft-label-DD/dataset-distillation/utils/multiprocessing.py", line 66, in join_all
    self.join(i)
  File "/home/user/soft-label-DD/dataset-distillation/utils/multiprocessing.py", line 53, in join
    raise p.exception.reconstruct()
ValueError: Traceback (most recent call last):
  File "/home/user/soft-label-DD/dataset-distillation/utils/multiprocessing.py", line 27, in run
    mp.Process.run(self)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/user/soft-label-DD/dataset-distillation/utils/io.py", line 52, in _vis_results_fn
    if std:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

It seems the error is due to some multiprocessing settings. Any idea on how to fix this error?

Thank you very much!
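
Possibly relevant: the traceback ends at "if std:" in utils/io.py, which truth-tests what looks like a multi-element NumPy array. A minimal sketch of the failure and two unambiguous alternatives (which one matches the intent of that line is an assumption):

    import numpy as np

    std = np.array([0.229, 0.224, 0.225])  # hypothetical per-channel std

    # if std:  # raises: truth value of a multi-element array is ambiguous
    if std is not None:   # if the intent is "was a std provided at all?"
        print("std provided")
    if np.any(std != 0):  # if the intent is "is any channel non-zero?"
        print("non-trivial std")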
