Comments (18)
Anyone interested in implementing this?
Will need to make sure it's supported with distributed wrappers. Specifically, the DistributedDataset wrapper.
from lightning.
Another option would be to just concat the datasets with torch.utils.data.ConcatDataset.
It's kind of a quick fix, but I'm pretty sure the existing DistributedDataset wrapper would then handle it the same as a single dataset.
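For the combined-set route, a minimal sketch of what that would look like with torch.utils.data.ConcatDataset (the dataset sizes here are made up for illustration):

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Two hypothetical validation sets of different sizes
ds_a = TensorDataset(torch.zeros(10, 3))
ds_b = TensorDataset(torch.ones(4, 3))

# ConcatDataset chains them end to end: len == 10 + 4
combined = ConcatDataset([ds_a, ds_b])
loader = DataLoader(combined, batch_size=2, shuffle=False)

print(len(combined))      # 14
print(len(list(loader)))  # 7 batches of size 2
```

Because the result is a single dataset, the existing distributed wrapping would see nothing unusual.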
@sidhanthholalkere yeah, actually that might be a better option.
I like deferring this stuff to PyTorch.
So, to use two datasets, the user would do the following (according to this issue):
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

train_loader = torch.utils.data.DataLoader(
    ConcatDataset(
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B)
    ),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
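As a quick sanity check on the pattern above (renamed ZipDataset here to avoid clashing with torch's own ConcatDataset; tensor shapes are invented), each batch the loader yields is a tuple with one collated batch per underlying dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

class ZipDataset(Dataset):
    """Same idea as the ConcatDataset class above: pair datasets index-wise."""
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

ds_a = TensorDataset(torch.zeros(6, 3))
ds_b = TensorDataset(torch.ones(6, 3))
loader = DataLoader(ZipDataset(ds_a, ds_b), batch_size=2)

# Default collation keeps the per-dataset structure intact
batch_a, batch_b = next(iter(loader))
print(batch_a[0].shape, batch_b[0].shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```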
So it looks like Lightning has to do nothing here, except maybe add a documentation block to help users looking to do this?
Under the Validation step section:
- [Validation with multiple datasets][link to docs with details]
My issue with that method is that some samples in the larger dataset will be left out. The only benefit is that you get to validate a batch from both datasets at the same time, but in my experience only the full validation loss (over the whole validation set) matters, not the per-batch loss.
Now, if the user wants separate losses for each set, I think the best option would be:
for dataset_index, dataset in enumerate(datasets):
    for batch_nb, batch in enumerate(dataset):
        model.validation_step(batch, batch_nb, dataset_index)
If they just want the validation on both sets combined, I still think using the built-in torch.utils.data.ConcatDataset here is better, because it handles some error checking and would essentially do the same thing as options B and C.
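To make the difference concrete (toy sizes, assumed for illustration): index-wise pairing is capped at the smaller set's length, while torch.utils.data.ConcatDataset visits every sample exactly once:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

ds_small = TensorDataset(torch.zeros(100, 1))
ds_large = TensorDataset(torch.ones(250, 1))

# Index-wise pairing (the min-length approach): limited by the smaller set,
# so 150 samples of the larger set would never be validated
paired_len = min(len(ds_small), len(ds_large))

# torch's ConcatDataset: every sample from both sets is visited once
concat_len = len(ConcatDataset([ds_small, ds_large]))

print(paired_len, concat_len)  # 100 350
```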
that looks great. @sidhanthholalkere Want to give it a shot and submit a PR?
Files that need to be changed:
- trainer
  - add support for returning either a single dataloader or multiple dataloaders from val_dataloader
- tests
  - testModels (there are 2; one should return a single val_dataloader, the other should return 2). A test should be written to make sure both dataloaders are called and used correctly.
I've started implementing this locally. I originally started by having validate() take in dataset_index, so validation_step() could have access to it and the user could name the outputs accordingly, i.e.:

return {'val_loss_{}'.format(dataset_index): whatever}

Looking back, it feels like this adds unnecessary complexity because the user has to decide what to do with the dataset index.
Now I'm trying a new approach where validate() itself takes in the list of val_dataloaders (instead of enumerating through them and passing dataset_index) and then just appends the dataset_index to each of the result keys, i.e.:

output = {key + str(dataset_index): value for key, value in output.items()} if len(dataloaders) > 1 else output

With this new method, there's no need to enumerate through val_dataloaders when calling validate(), AND the user doesn't need to handle dataset_index.
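That key-suffixing idea can be sketched as a small helper (the function name and metric keys are illustrative, not Lightning's actual API):

```python
def tag_outputs(output, dataset_index, num_dataloaders):
    """Append the dataloader index to each metric key when there are
    multiple val dataloaders; leave single-loader output untouched.
    Illustrative sketch, not the real trainer code."""
    if num_dataloaders <= 1:
        return output
    return {f"{key}_{dataset_index}": value for key, value in output.items()}

out = tag_outputs({"val_loss": 0.25, "val_acc": 0.9},
                  dataset_index=1, num_dataloaders=2)
print(out)  # {'val_loss_1': 0.25, 'val_acc_1': 0.9}
```

With one loader the keys pass through unchanged, which keeps existing single-dataloader models working.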
What do you think of this new method?
Now, for writing tests: do you want me to just check that model.nb_val_batches is correct and that model.validate() works (since that is all I'm changing, at least in the new method)?
Also, for creating the test with two val_dataloaders, should I just use the default get_model() and override val_dataloader() as follows?

@ptl.data_loader
def val_dataloader(self):
    print('val data loader called')
    return [self.__dataloader(train=False) for i in range(2)]
Interesting suggestion but I like the version with the dataset index, just pass it after batch_nb. This way it’s intuitive, users don’t have to read docs, and it remains fully flexible.
So, I propose we do something along the lines of:
for dataset_i, dataset in enumerate(val_datasets):
    for batch_nb, batch in enumerate(dataset):
        out = model.validation_step(batch, batch_nb, dataset_i)
It’s also backwards compatible and what a researcher would do if they had to implement it on their own.
I've added support for multiple val_dataloaders on my fork.
Is there a specific way you want me to write tests?
Also, unfortunately I won't be able to run most of the other tests locally, as I don't have a multi-GPU machine or a working apex install for now.
awesome contribution!
for tests, can you add multiple val dataloaders to the exampleModel used in tests? then modify what's returned to be an accuracy and val loss indexed by the dataset.
add a separate test to the trainer where it inits dataloaders, then check that the loaders are correct (we don't have such a test yet).
make sure to also check that all val dataloaders are being wrapped in a distributed dataset (there's a warning for that).
i can run it on gpus once you submit.
when you say exampleModel, do you mean the template model (lightning_module_template.py)?
Also, how would you want me to init multiple dataloaders? I could make a child of LightningTemplateModel where val_dataloader returns 2 dataloaders instead of one.
For the warning about val_dataloaders being wrapped, should I write another exception in trainer.py similar to the one that checks if ddp is used and if the tng_dataloader uses a DistributedSampler?
@sidhanthholalkere
1. this is exampleModel: https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/testing/lm_test_module.py
2. return 2 dataloaders from the module above. That way some of the tests use a different model which returns 1 loader and other tests use this model which returns 2:

   return [ds1, ds2]

3. just use:

   import warnings
   warnings.warn('something to warn about')

4. in fact, while you're there, could you remove the exception about the distributed sampler and turn it into a warning? that would solve #81.
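A possible shape for that change, with the check relaxed from an exception to a warning (check_val_sampler and the attribute access are assumptions for illustration, not the real trainer code):

```python
import warnings

def check_val_sampler(dataloader, use_ddp):
    """Sketch: warn instead of raising when ddp is on but the val
    dataloader is not using a DistributedSampler."""
    sampler = getattr(dataloader, "sampler", None)
    if use_ddp and type(sampler).__name__ != "DistributedSampler":
        warnings.warn(
            "You're using multiple GPUs without a DistributedSampler on a "
            "val dataloader; samples may be duplicated across processes."
        )
```

Training proceeds either way; the user just gets told about possible duplicated samples instead of being stopped.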
Ok, I've made the changes. Here is a summary of what I've done:
- Multiple val_dataloader support in trainer.py
- Added 2 val_dataloaders for lm_test_module.py (it's just the same one twice)
- Added an output to validation_step (if batch_i % 4 == 0) that has the losses/accuracies indexed by dataset
- Warning for if val_dataloaders are not DistributedSamplers and ddp is selected
- Test: fit a model with multiple val_dataloaders and check if the length of the trainer's val_dataloader is 2 (not sure about this test)

Let me know if anything should be changed (it's on my fork if you want to check) before I submit a PR.
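The two-loader test could look roughly like this (make_val_dataloaders is a stand-in; get_model and __dataloader from the thread are not reproduced here):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_val_dataloaders(n=2):
    """Stand-in for a val_dataloader override returning two loaders."""
    ds = TensorDataset(torch.zeros(8, 3))
    return [DataLoader(ds, batch_size=4) for _ in range(n)]

# The test would fit a model returning these and then assert the
# trainer picked up both loaders
val_dataloaders = make_val_dataloaders()
assert len(val_dataloaders) == 2
assert all(isinstance(dl, DataLoader) for dl in val_dataloaders)
```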
Quick nitpick: in lm_test_module.py, in validation_step(), the acc and loss are named val_acc and loss_val; could I change it so the naming is consistent?
Also, in your test_models.py, multiple comments say "traning complete", so I can fix that in another PR.
@sidhanthholalkere awesome! probably easier to make edits and comments on the PR itself!
@sidhanthholalkere 4 was taken care of in a different PR by someone yesterday
Any update?
I actually need to use two dataloaders for validation and take the mean of the logits. I tried to return a list of loaders from val_dataloader, but it does not work:

TypeError: 'DataLoader' object is not subscriptable
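For reference, that error is reproducible with a bare DataLoader, since DataLoader defines no __getitem__ (minimal standalone snippet, not Lightning code); it suggests something is indexing into the loader as if it were a list:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.zeros(4, 1)))
try:
    loader[0]  # DataLoader is iterable, not subscriptable
except TypeError as e:
    print(e)   # 'DataLoader' object is not subscriptable
```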
not live yet. it's on @sidhanthholalkere's branch. waiting on a PR to merge
Fixing some errors, should be finished soon
@lorenzoFabbri @sidhanthholalkere merged! was not super trivial to verify haha