optexp's Issues
Make test to check that iterations with multiple devices are consistent with single device
Make test to check that seed initialization is repeatable
Make toy dataset and toy models for tests
Training loop
We should simplify the training loop to work only in terms of iterations internally, but keep track of the epoch counter for logging purposes. Having a single training-loop using lightning would make things much simpler.
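A minimal sketch of what that could look like, assuming a plain loop where `dataloader`, `train_step`, and `log` are placeholders rather than the current API:

```python
# Sketch: iteration-driven training loop; the epoch counter is only logging metadata.
def train(dataloader, train_step, log, max_steps):
    step, epoch = 0, 0
    data_iter = iter(dataloader)
    while step < max_steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            epoch += 1                  # finished a pass over the data
            data_iter = iter(dataloader)
            batch = next(data_iter)
        loss = train_step(batch)
        log({"step": step, "epoch": epoch, "loss": loss})
        step += 1
```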
Wandb controls
Need a new environment variable / CLI override for "wandb upload at the end".
Current warning looks wrong
2024-07-07 13:02:27 WARNING [config.py:79] Wandb autosync not specified. Defaults not syncing.To enable autosync, set the OPTEXP_WANDB_ENABLED to true.
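One way such an override could be read, sketched with a hypothetical variable name (`OPTEXP_WANDB_AUTOSYNC` stands in for whatever name is chosen; this is not the current config code):

```python
import os

# Hypothetical env-var override with an explicit default; the name and semantics are assumptions.
def wandb_autosync_enabled(default: bool = False) -> bool:
    value = os.environ.get("OPTEXP_WANDB_AUTOSYNC")
    if value is None:
        return default  # this is where the "autosync not specified" warning would be emitted
    return value.strip().lower() in ("1", "true", "yes")
```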
Reducing dependencies
The current list of non-pytorch dependencies is long.
Can we remove some of them?
- Can we remove the dependencies on `numpy`, `scipy`, and `matplotlib` if we remove the plotting code from this library and treat it separately?
  - We can remove the dependency on `scipy` and `matplotlib`
- Can we remove the dependency on `pyarrow`? Can't remember what it's used for. Maybe for storing the experiment files locally when downloaded?
  - No, we can't, need dataframes to store the data
- Can we remove the dependency on `pandas` if we don't deal with plotting code? Might be necessary to return something useful when loading experiment results
  - No, we can't, need dataframes to store the data
- Can we remove the dependency on `vit_pytorch` by just porting the model we actually want? Can also remove this experiment for now.
  - done
Current deps
# Main requirements for the library
tqdm>=4.65.0
wandb>=0.15.3
lightning>=2.1.3
pandas>=1.5.3
pyarrow>=13.0.0
Starting dependencies
# Main requirements for the library
numpy>=1.24.2
scipy>=1.10.1
pandas>=1.5.3
tqdm>=4.65.0
wandb>=0.15.3
pyarrow>=13.0.0
matplotlib>=3.6.1
vit_pytorch>=1.6.5
lightning>=2.1.3
`check` and `download` are broken
Probably due to the changes in experiment ids.
Can be reproduced with examples/small_gridsearch.py
Training with the CrossEntropy loss on a regression dataset fails silently
Example configuration below. It will run, but the result is not meaningful.
I'm surprised it doesn't hit a pytorch type error - doesn't CrossEntropy expect `y` to be `long`?
Is there anything we can check to catch silly mistakes like this?
```python
from optexp.cli import cli
from optexp.datasets import DummyRegression
from optexp.experiment import Experiment
from optexp.hardwareconfig.strict_manual import StrictManualConfig
from optexp.metrics.metrics import Accuracy, CrossEntropy
from optexp.models import Linear
from optexp.optim.sgd import SGD
from optexp.problem import Problem
from optexp.runner.slurm.slurm_config import SlurmConfig

experiments = [
    Experiment(
        optim=SGD(lr=lr, momentum=momentum),
        problem=Problem(
            dataset=DummyRegression(),
            model=Linear(),
            lossfunc=CrossEntropy(),
            metrics=[CrossEntropy(), Accuracy()],
            batch_size=100,
        ),
        group="testing",
        eval_every=1,
        seed=0,
        steps=10,
        hardware_config=StrictManualConfig(
            num_devices=1,
            micro_batch_size=10,
            eval_micro_batch_size=10,
            device="cpu",
        ),
    )
    for lr in [10**-2, 10**-1, 10**0]
    for momentum in [0, 0.9]
]

SLURM_CONFIG = SlurmConfig(hours=1, gb_ram=8, n_cpus=1, n_gpus=1, gpu=True)

if __name__ == "__main__":
    cli(experiments, slurm_config=SLURM_CONFIG)
```
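A possible guard for this class of mistake, sketched as a standalone check rather than anything in the library. One likely explanation for the silence: recent PyTorch versions also accept floating-point targets as class probabilities, so no dtype error is raised. The sketch assumes the loss bottoms out in `torch.nn.CrossEntropyLoss` and that a batch of targets is available up front:

```python
import torch

# Hypothetical sanity check: cross-entropy classification expects integer class targets,
# while a regression dataset typically yields float targets.
def check_targets_for_cross_entropy(y: torch.Tensor) -> None:
    if y.dtype != torch.long:
        raise ValueError(
            f"CrossEntropy expects class indices of dtype torch.long, got {y.dtype}; "
            "is this a regression dataset?"
        )
```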
Long term/dream features
Automatic batch size selection
Checkpointing
This one is more complex.
Step 1. Figure out how to achieve a reproducible dataloader if we don't use an epoch-based system.
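One way to get there, sketched under the assumption that we sample each batch independently (IID) rather than shuffling per epoch: derive the indices for step `t` from `(seed, t)`, so resuming from a checkpoint only requires the step counter. Names are illustrative, not the library's API.

```python
import torch

# Hypothetical resumable sampling: the batch at step `t` depends only on (seed, t),
# so a run restarted from a checkpoint at step t sees exactly the same batches.
def batch_indices(seed: int, step: int, num_samples: int, batch_size: int) -> torch.Tensor:
    generator = torch.Generator().manual_seed(seed * 1_000_003 + step)
    return torch.randperm(num_samples, generator=generator)[:batch_size]
```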
Introduce metrics
Metrics are currently hardcoded in the problem class.
It would be cleaner to pass a list of metrics to the experiment class to decide what to log.
For this we need an interface to define how metrics work.
There are a few corner cases that are not yet clear how to handle, in how to specify (one possible interface is sketched after this list):
- Inputs: different metrics need different inputs, e.g.
  - the L2 norm of the weights just needs the model
  - "live" accuracy needs the predictions and targets for the current batch
  - full accuracy needs the predictions and targets for the entire dataset
  - some need the training set, some need the validation set
- Evaluation frequency: different metrics need different evaluation frequencies
  - live metrics can be run each iteration
  - entire-dataset metrics should not
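A rough sketch of one possible interface; the granularity and names are assumptions, nothing here is the current code:

```python
from abc import ABC, abstractmethod

import torch


class Metric(ABC):
    # Cheap "live" metrics can run every iteration; dataset-wide ones should not.
    per_iteration: bool = True

    @abstractmethod
    def compute(self, model: torch.nn.Module, y_pred: torch.Tensor, y: torch.Tensor) -> float: ...


class WeightNorm(Metric):
    """Only needs the model."""

    def compute(self, model, y_pred, y):
        return torch.norm(torch.cat([p.detach().flatten() for p in model.parameters()])).item()


class LiveAccuracy(Metric):
    """Needs the predictions and targets of the current batch."""

    def compute(self, model, y_pred, y):
        return (y_pred.argmax(dim=-1) == y).float().mean().item()
```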
Need some form of logging during training
Clean up datasets
The dataset directory is a bit spaghetti, with `load`, `download`, `dataloader`, and other niceties spread across multiple files for each dataset.
It would be nice to have a single file/class for each dataset with a unified interface (sketched after this list) that specifies how to
- download the raw data and put it in the workspace folder (or -- if a synthetic dataset -- create the data)
- load the dataset
- create dataloaders
Common features (eg. how to handle image data) can be split into separate functions instead.
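A sketch of what the unified interface could look like; the names are illustrative only:

```python
from abc import ABC, abstractmethod
from pathlib import Path

from torch.utils.data import DataLoader


class Dataset(ABC):
    @abstractmethod
    def download(self, workspace: Path) -> None:
        """Fetch the raw data into the workspace folder (or generate it, if synthetic)."""

    @abstractmethod
    def load(self, workspace: Path):
        """Return the in-memory train/validation splits."""

    @abstractmethod
    def get_dataloader(self, split: str, batch_size: int, shuffle: bool) -> DataLoader:
        """Build a dataloader over one split."""
```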
Step 1 could be to remove all the existing fluff and make a clean version of `MNIST`, `PTB`, and a linear regression dataset (e.g. some LIBSVM dataset like abalone) to make sure the interface works for different data types.
- Remove unused datasets
- Make a clean interface for the dataset and implement
- MNIST
- ImageNet
- PTB
- WT2 & WT103
Set up git repo
- Move code from summer23
- ReadTheDocs integration
- Github actions
- black
- mypy
- pylint
- Set up tests
Rework Problem class
The problem class currently defines some pieces of the training loop.
This is a bit much, and some of it could be moved out of the class and into the training loop.
The biggest problem with the Problem class is the hard-coded definition of the metrics, which we can get around if we introduce metrics separately as in #4. If we do this we can get rid of 90% of the code in this folder.
But I'd still want to keep the concept of a `Problem` around.
From the perspective of evaluating optimizers, the `Problem` is what we'd like to keep constant while trying many optimizers against it.
Having a consistent naming scheme and a set of problems would be nice for this,
and we can't really get there by taking the cross product of datasets and models, because some models only work with some datasets and don't make sense otherwise. `Problem` seems to be a nice abstraction for that.
And that abstraction might actually be necessary if we need to introduce more complicated "problems", say training a GAN, where the loss computation is technically different from just `loss(m(x), y)`.
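Going by the fields visible in the example configuration above (dataset, model, lossfunc, metrics, batch_size), a `Problem` stripped of training-loop logic could be little more than a frozen container. The sketch below is an assumption about where this could land, not the current class:

```python
from dataclasses import dataclass, field


# Sketch: Problem as the named bundle of everything held constant across optimizers.
@dataclass(frozen=True)
class Problem:
    dataset: object
    model: object
    lossfunc: object
    metrics: list = field(default_factory=list)
    batch_size: int = 128
```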
Add `num_workers` to hardwareconfig
Dataset and Dataloader - Batch-size and rounding-related parameters
What would be a reasonable way to handle the parameters related to the batch size and order of minibatches?
- `batch_size`
- `micro_batch_size`
- `shuffle`
- `drop_last`
- `sampler`
Complete list of requirements, probably not feasible to hit all at once:
- Make it easy to run full batch experiments
- Make it easy to compare sampling IID and shuffling
- Stay close to PyTorch's default behavior
- Reproducible behavior across micro batch size so its selection can be automated
- Fit all the above with default parameters
The `drop_last` problem
Suppose we have `N = number_of_samples`, `B = batch_size`, and for simplicity `batch_size == micro_batch_size`.
The thing we want to avoid is taking a step with a sample size smaller than `B`.
I can see at least two ways to handle `drop_last`:
1. Use only a subset of the data at each epoch, missing some samples at each epoch (PyTorch default).
2. Truncate the dataset at the start and use all the data at each epoch.
Both miss out on `N mod B < B` samples, but (1) misses them at random at every epoch, while (2) misses none but works on a reduced dataset.
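A small worked example with made-up numbers:

```python
N, B = 1_050, 100   # made-up sizes
leftover = N % B    # 50 samples do not fit into a full batch of size B

# Option (1), PyTorch's drop_last with shuffling:
#   every epoch takes 10 full steps over 1000 samples,
#   but *which* 50 samples get dropped changes from epoch to epoch.
# Option (2), truncate once up front:
#   the dataset becomes N - leftover = 1000 samples,
#   and the same 1000 samples are used at every epoch.
```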
Additional difficulty 1: Interaction with multiprocessing
Using multiprocessing, `drop_last` is applied per dataloader on each worker node.
With `W` workers, we will miss `N mod (B*W) < B*W` samples.
Additional difficulty 2: Interaction with micro batch size /!\
The dataloader will iterate over blocks of `micro_batch_size`, not `batch_size`.
But if the dataloader has fewer than `B` samples left, we still shouldn't take a step, because it would be incomplete?
This cannot be implemented with the default PyTorch dataloader and `drop_last`, which only drops when the number of remaining samples is less than `microB`.
With multiprocessing, we should stop when the local dataloader contains fewer than `B/W` samples.
But if we leave `drop_last=False`, then we don't actually need `B mod microB == 0`,
because we can accommodate any batch size. Although we might have to randomly drop samples..?
(Maybe too much in the weeds)
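One way to phrase the stopping rule above, as a hypothetical helper around gradient accumulation rather than the current dataloader logic: only take an optimizer step once `B / microB` micro-batches have been accumulated, and silently drop any trailing incomplete group.

```python
# Hypothetical accumulation loop: an optimizer step happens only after a full
# batch of B = accum_steps * micro_B samples; a trailing incomplete group of
# micro-batches is discarded, i.e. drop_last applied at the level of B, not micro_B.
def run_epoch(micro_batches, accum_steps, backward_on_micro_batch, optimizer_step):
    accumulated = 0
    for micro_batch in micro_batches:
        backward_on_micro_batch(micro_batch)  # e.g. accumulate gradients
        accumulated += 1
        if accumulated == accum_steps:
            optimizer_step()
            accumulated = 0
    # any remaining accumulated micro-batches are dropped: no incomplete step is taken
```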
Additional difficulty 3: Consistency of the "problem" across seeds /!\
- If we do 1.
  - and disable shuffling: we never get to optimize a small subset of the dataset.
  - and don't disable shuffling: we don't run in full batch but in almost-full batch.
- If we do 2.
  - and use a different seed for the dataset: different datasets will lead to different problems every time.
  - and use the same seed for the dataset: we have the same dataset every time, but we have to make sure the validation dataloader is consistent with that.
Additional difficulty 4: Difference between training and evaluation /!\
We don't need `drop_last` on the evaluation datasets if we're only using them to go through the data once for evaluation.
Proposed solution
?
Sanity check:
- How does it work if we want to run in full batch?
  - Ideally we would just have to set `B = N`, instead of having to set a separate `full_batch` mode.
    - How do we know what `N` is?
    - What if it changes as a function of the batch size?
      - It doesn't (?), in the
    - What if `N` is a prime number and there is no way to set a reasonable micro batch size?
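A quick numeric check of the last two questions, with made-up sizes:

```python
# Made-up sizes. Full batch means B = N, and micro_batch_size must divide B.
N = 50_000                 # e.g. a 50k-sample dataset: micro_batch_size = 100 gives 500 micro-batches
assert N % 100 == 0

N_prime = 1_009            # a prime-sized dataset: no divisor strictly between 1 and N,
                           # so there is no reasonable micro batch size without
                           # padding or dropping samples
assert all(N_prime % m != 0 for m in range(2, N_prime))
```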
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
๐ Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. ๐๐๐
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google โค๏ธ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.