
Comments (18)

williamFalcon avatar williamFalcon commented on May 3, 2024

Anyone interested in implementing this?

Will need to make sure it's supported with distributed wrappers. Specifically, the DistributedDataset wrapper.


sholalkere avatar sholalkere commented on May 3, 2024

Another option would be to just concatenate the datasets with torch.utils.data.ConcatDataset.
It's kind of a quick fix, but I'm pretty sure the existing DistributedDataset wrapper would then handle it the same way it handles a single dataset.
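For example, a quick sketch with toy datasets:

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# two toy validation sets of different sizes
val_set_a = TensorDataset(torch.randn(100, 3), torch.zeros(100))
val_set_b = TensorDataset(torch.randn(150, 3), torch.ones(150))

# ConcatDataset chains them end to end: len(combined) == 250
combined = ConcatDataset([val_set_a, val_set_b])
val_loader = DataLoader(combined, batch_size=32)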


williamFalcon avatar williamFalcon commented on May 3, 2024

@sidhanthholalkere yeah, actually that might be a better option.
I like deferring this stuff to PyTorch.

So to use 2 datasets, the user would do something like the following (according to this issue):

import torch
from torchvision import datasets


class ConcatDataset(torch.utils.data.Dataset):
    """Pairs samples from several datasets by index
    (despite the name, this zips rather than concatenates)."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # return one sample from each dataset at the same index
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        # truncate to the shortest dataset
        return min(len(d) for d in self.datasets)


train_loader = torch.utils.data.DataLoader(
    ConcatDataset(
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B),
    ),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
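With this wrapper, each batch the loader yields is a tuple containing one sub-batch per wrapped dataset, and iteration stops at the length of the shortest dataset, since __len__ returns the min.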

So it looks like Lightning has to do nothing here, except maybe add a documentation block to help users looking to do this?

Under the Validation step section:

- [Validation with multiple datasets][link to docs with details]


sholalkere avatar sholalkere commented on May 3, 2024

My issue with that method is that some samples in the larger dataset will be left out, since __len__ returns the minimum of the dataset lengths. The only benefit would be that you get to validate a batch from both datasets at the same time, but in my experience only the full validation loss (on the whole validation set) matters, not the per-batch loss.
Now, if the user wants separate losses for each set, I think the best option would be:

for dataset_index, dataloader in enumerate(val_dataloaders):
    for batch_nb, batch in enumerate(dataloader):
        model.validation_step(batch, batch_nb, dataset_index)

If they just want the validation on both sets combined, I still think that using the built-in torch.utils.data.ConcatDataset here is better, because it handles some error checking and would essentially do the same thing as options B and C.


williamFalcon avatar williamFalcon commented on May 3, 2024

that looks great. @sidhanthholalkere Want to give it a shot and submit a PR?

Files that need to be changed:

  • trainer: add support for returning either a single dataloader or multiple dataloaders from val_dataloader (see the sketch below)
  • tests: the test models (there are 2; one should return a single val_dataloader, the other should return 2). A test should be written to make sure both dataloaders are called and used correctly.
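A rough sketch of the trainer-side change (illustrative, not the final code):

# accept either a single DataLoader or a list of them from val_dataloader
val_dataloaders = model.val_dataloader()
if not isinstance(val_dataloaders, (list, tuple)):
    val_dataloaders = [val_dataloaders]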


sholalkere avatar sholalkere commented on May 3, 2024

I've started implementing this locally. I originally had validate() take in dataset_index so that validation_step() could access it and the user could name the outputs accordingly, i.e.:
return {'val_loss_{}'.format(dataset_index): whatever}
Looking back, it feels like this adds unnecessary complexity, because the user has to decide what to do with dataset_index.

Now I'm trying a new approach where validate() itself takes in the list of val_dataloaders (instead of previously having to enumerate through them and pass dataset_index) and then just appends the dataset_index to each of the result keys, i.e.:

output = {key + str(dataset_index): value for key, value in output.items()} if len(val_dataloaders) > 1 else output

With this new method, there's no need to enumerate through val_dataloader when calling validate, AND the user doesn't need to handle dataset_index.
What do you think of this new method?
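For concreteness, a rough sketch of what that validate() loop might look like (simplified; the names are illustrative):

def validate(model, val_dataloaders):
    outputs = []
    for dataset_index, dataloader in enumerate(val_dataloaders):
        for batch_nb, batch in enumerate(dataloader):
            output = model.validation_step(batch, batch_nb)
            # suffix each metric key with the dataset index when
            # there is more than one val dataloader
            if len(val_dataloaders) > 1:
                output = {key + str(dataset_index): value
                          for key, value in output.items()}
            outputs.append(output)
    return outputs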

Now, how should I write the tests?
Do you want me to just check that model.nb_val_batches is correct and that model.validate() works (since that's all I'm changing, at least in the new method)?
Also, for creating the test with two val_dataloaders, should I just use the default get_model() and override val_dataloader() to

@ptl.data_loader
def val_dataloader(self):
    print('val data loader called')
    return [self.__dataloader(train=False) for i in range(2)]


williamFalcon avatar williamFalcon commented on May 3, 2024

Interesting suggestion, but I like the version with the dataset index: just pass it after batch_nb. This way it's intuitive, users don't have to read the docs, and it remains fully flexible.

So, I propose we do something along the lines of:

for dataset_i, dataset in enumerate(val_datasets):
    for batch_nb, batch in enumerate(dataset):
        out = model.validation_step(batch, batch_nb, dataset_i)

It’s also backwards compatible and what a researcher would do if they had to implement it on their own.
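On the user side, that would look something like this (a sketch; assumes import torch.nn.functional as F, and the metric naming is up to the user):

# inside a LightningModule
def validation_step(self, batch, batch_nb, dataset_i):
    x, y = batch
    y_hat = self.forward(x)
    loss = F.cross_entropy(y_hat, y)
    # key the metric by which val dataloader it came from
    return {'val_loss_{}'.format(dataset_i): loss}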


sholalkere avatar sholalkere commented on May 3, 2024

I've added support for multiple val_dataloaders on my fork.
Is there a specific way you want me to write the tests?
Also, unfortunately, I won't be able to run most of the other tests locally, as I don't have a multi-GPU machine or working apex for now.


williamFalcon avatar williamFalcon commented on May 3, 2024

awesome contribution!

for tests, can you add multiple val dataloaders to exampleModel used in tests? then modify what's returned to be an accuracy and val loss indexed by the dataset.

add a separate test to trainer where it inits dataloaders, then check that the loaders are correct (we don't have such a test yet).

make sure to also check that all val dataloaders are being wrapped in distributed dataset (there’s a warning for that).

i can run in gpus once you submit.


sholalkere avatar sholalkere commented on May 3, 2024

when you say exampleModel, do you mean the template model (lightning_module_template.py)?
Also, how would you want me to init multiple dataloaders? I could make a child of LightningTemplateModel where val_dataloader returns 2 dataloaders instead of one.

For the warning about val_dataloaders being wrapped, should I write another exception in trainer.py, similar to the one that checks whether ddp is used and whether the tng_dataloader uses a DistributedSampler?


williamFalcon avatar williamFalcon commented on May 3, 2024

@sidhanthholalkere

  1. this is exampleModel https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/testing/lm_test_module.py.

  2. return 2 dataloaders from the module above. That way some of the tests use a different model which returns 1 loader and other tests use this model which returns 2.
    return [ds1, ds2]

  3. just use

import warnings

warnings.warn('something to warn about')

  4. in fact, while you're there, could you remove the exception about distsampler and turn it into a warning? that would solve #81.
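a rough sketch of that check (use_ddp and val_dataloaders are placeholder names):

import warnings
from torch.utils.data.distributed import DistributedSampler

# warn (rather than raise) when ddp is enabled but a val loader
# isn't using a DistributedSampler
for loader in val_dataloaders:
    if use_ddp and not isinstance(loader.sampler, DistributedSampler):
        warnings.warn('val_dataloader(s) are not using DistributedSampler; '
                      'each process will run validation on the full dataset')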


sholalkere avatar sholalkere commented on May 3, 2024

Ok, I've made the changes, here is a summary of what I've done:

  1. Multiple val_dataloader support in trainer.py
  2. Added 2 val_dataloaders to lm_test_module.py (it's just the same one twice)
  3. Added an output to validation_step (if batch_i % 4 == 0) that has the losses/accuracies indexed by dataset
  4. A warning for when ddp is selected but the val_dataloaders are not using DistributedSampler
  5. A test that fits a model with multiple val_dataloaders and checks that the length of the trainer's val_dataloader is 2 (not sure about this test)

Let me know if anything should be changed (it's on my fork if you want to check) before I submit a PR.
Quick nitpick: in lm_test_module.py, in validation_step(), the acc and loss are named val_acc and loss_val. Could I change them so the naming is consistent?
Also, in your test_models.py, multiple comments say "traning complete", so I can fix that typo in another PR.


williamFalcon avatar williamFalcon commented on May 3, 2024

@sidhanthholalkere awesome! probably easier to make edits and comments on the PR itself!


williamFalcon avatar williamFalcon commented on May 3, 2024

@sidhanthholalkere item 4 was taken care of in a different PR by someone yesterday


lorenzoFabbri avatar lorenzoFabbri commented on May 3, 2024

Any update?
I actually need to use two data loaders for validation and take the mean of the logits. I tried returning a list of loaders from val_dataloader, but it does not work:

TypeError: 'DataLoader' object is not subscriptable.


williamFalcon avatar williamFalcon commented on May 3, 2024

not live yet. it's on @sidhanthholalkere's branch. waiting on a PR to merge


sholalkere avatar sholalkere commented on May 3, 2024

Fixing some errors, should be finished soon


williamFalcon avatar williamFalcon commented on May 3, 2024

@lorenzoFabbri @sidhanthholalkere merged! was not super trivial to verify haha

