optexp's Issues

Training loop

We should simplify the training loop to work only in terms of iterations internally, but keep track of the epoch counter for logging purposes. Having a single training loop built on Lightning would make things much simpler.
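As a rough sketch of the iteration-based structure (plain PyTorch with a hard-coded loss, purely illustrative rather than the Lightning-based loop the issue asks for):

import torch

def train(model, dataloader, optimizer, max_steps):
    # Iteration-based loop; the epoch counter exists only for logging.
    step, epoch = 0, 0
    data_iter = iter(dataloader)
    while step < max_steps:
        try:
            x, y = next(data_iter)
        except StopIteration:
            epoch += 1  # not used for control flow
            data_iter = iter(dataloader)
            x, y = next(data_iter)
        loss = torch.nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        print(f"step={step} epoch={epoch} loss={loss.item():.4f}")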

Wandb controls

We need a new environment variable / CLI override for "wandb upload at the end".

Current warning looks wrong

2024-07-07 13:02:27 WARNING [config.py:79] Wandb autosync not specified. Defaults not syncing.To enable autosync, set the OPTEXP_WANDB_ENABLED to true.
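A minimal sketch of how such an override could be resolved. The variable name OPTEXP_WANDB_SYNC_AT_END and the CLI flag are placeholders, not the existing interface:

import os

def should_sync_at_end(cli_flag=None):
    # Hypothetical resolution order: CLI flag wins, then the environment
    # variable, defaulting to not syncing.
    if cli_flag is not None:
        return cli_flag
    return os.environ.get("OPTEXP_WANDB_SYNC_AT_END", "false").lower() == "true"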

Reducing dependencies

The current list of non-PyTorch dependencies is long.
Can we remove some of them?

  • Can we remove the dependencies on numpy, scipy, and matplotlib if we remove the plotting code from this library and treat it separately?
    • We can remove the dependencies on scipy and matplotlib.
  • Can we remove the dependency on pyarrow? Can't remember what it's used for. Maybe for storing the experiment files locally when downloaded?
    • No, we can't; we need dataframes to store the data.
  • Can we remove the dependency on pandas if we don't deal with plotting code? Might be necessary to return something useful when loading experiment results.
    • No, we can't; we need dataframes to store the data.
  • Can we remove the dependency on vit_pytorch by just porting the model we actually want? We could also remove this experiment for now.
    • Done.

Current deps

# Main requirements for the library
tqdm>=4.65.0
wandb>=0.15.3
lightning>=2.1.3
pandas>=1.5.3
pyarrow>=13.0.0

Starting dependencies

# Main requirements for the library
numpy>=1.24.2
scipy>=1.10.1
pandas>=1.5.3
tqdm>=4.65.0
wandb>=0.15.3
pyarrow>=13.0.0
matplotlib>=3.6.1
vit_pytorch>=1.6.5
lightning>=2.1.3

Training with the CrossEntropy loss on a regression dataset fails silently

Example configuration below. It will run, but the result is not meaningful.
I'm surprised it doesn't hit a PyTorch type error; doesn't CrossEntropy expect y to be long?

Is there anything we can check to catch silly mistakes like this?

from optexp.cli import cli
from optexp.datasets import DummyRegression
from optexp.experiment import Experiment
from optexp.hardwareconfig.strict_manual import StrictManualConfig
from optexp.metrics.metrics import Accuracy, CrossEntropy
from optexp.models import Linear
from optexp.optim.sgd import SGD
from optexp.problem import Problem
from optexp.runner.slurm.slurm_config import SlurmConfig

experiments = [
    Experiment(
        optim=SGD(lr=lr, momentum=momentum),
        problem=Problem(
            dataset=DummyRegression(),
            model=Linear(),
            lossfunc=CrossEntropy(),
            metrics=[CrossEntropy(), Accuracy()],
            batch_size=100,
        ),
        group="testing",
        eval_every=1,
        seed=0,
        steps=10,
        hardware_config=StrictManualConfig(
            num_devices=1,
            micro_batch_size=10,
            eval_micro_batch_size=10,
            device="cpu",
        ),
    )
    for lr in [10**-2, 10**-1, 10**0]
    for momentum in [0, 0.9]
]

SLURM_CONFIG = SlurmConfig(hours=1, gb_ram=8, n_cpus=1, n_gpus=1, gpu=True)
if __name__ == "__main__":
    cli(experiments, slurm_config=SLURM_CONFIG)
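One option would be a sanity check on the target dtype before training starts. The sketch below is illustrative and not part of the current codebase:

import torch

def check_classification_targets(lossfunc_name, targets):
    # Hypothetical sanity check: classification losses need integer class labels.
    if lossfunc_name == "CrossEntropy" and targets.dtype != torch.long:
        raise ValueError(
            f"CrossEntropy expects integer class indices, "
            f"but the dataset provides targets of dtype {targets.dtype}."
        )

Such a check could run once on the first batch, before the training loop starts.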

Long-term/dream features

Automatic batch size selection

Checkpointing

This one is more complex.

Step 1. Figure out how to achieve a reproducible dataloader if we don't use an epoch-based system.
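One possible direction, sketched under the assumption of single-process data loading: drive shuffling from an explicit torch.Generator and save its state with the checkpoint. Resuming mid-epoch would additionally require skipping the batches already consumed.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=10, shuffle=True, generator=generator)

# Save the generator state alongside the model/optimizer checkpoint ...
checkpoint = {"rng_state": generator.get_state()}
# ... and restore it before resuming, so the shuffling order continues as before.
generator.set_state(checkpoint["rng_state"])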

Introduce metrics

Metrics are currently hardcoded in the problem class.

It would be cleaner to pass a list of metrics to the experiment class to decide what to log.

For this we need an interface to define how metrics work; a rough sketch follows the list below.

There are a few corner cases where it is not yet clear how to specify:

  • Inputs: different metrics need different inputs, e.g.
    • the L2 norm of the weights just needs the model
    • "live" accuracy needs the predictions and targets for the current batch
    • full accuracy needs the predictions and targets for the entire dataset
      • some need the training set, some need the validation set
  • Evaluation frequency: different metrics need different evaluation frequencies
    • live metrics can be run at every iteration
    • entire-dataset metrics should not be
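A possible shape for that interface; the names (Metric, WeightNorm, FullAccuracy, the per_step flag) are illustrative assumptions, not the current code:

from abc import ABC, abstractmethod
import torch

class Metric(ABC):
    # Whether the metric is cheap enough to log at every iteration ("live"),
    # or should only run during full-dataset evaluations.
    per_step: bool = True

    @abstractmethod
    def __call__(self, model, predictions, targets): ...

class WeightNorm(Metric):
    per_step = True  # only needs the model, can run every iteration

    def __call__(self, model, predictions, targets):
        return torch.sqrt(sum(p.norm() ** 2 for p in model.parameters()))

class FullAccuracy(Metric):
    per_step = False  # needs predictions/targets aggregated over a whole dataset

    def __call__(self, model, predictions, targets):
        return (predictions.argmax(dim=-1) == targets).float().mean()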

Clean up datasets

The dataset directory is a bit of a spaghetti mess, with loading, downloading, dataloaders, and other niceties spread across multiple files for each dataset.

It would be nice to have a single file/class for each dataset with a unified interface that specifies how to

  • download the raw data and put it in the workspace folder (or -- if a synthetic dataset -- create the data)
  • load the dataset
  • create dataloaders

Common features (e.g. how to handle image data) can be split into separate helper functions instead; a rough sketch of such an interface follows.
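The method names below are illustrative assumptions, not the current code:

from abc import ABC, abstractmethod
from pathlib import Path
from torch.utils.data import DataLoader

class Dataset(ABC):
    # Sketch of a per-dataset class bundling download, load, and dataloaders.

    @abstractmethod
    def download(self, workspace: Path) -> None:
        """Fetch the raw data (or generate it for synthetic datasets)."""

    @abstractmethod
    def load(self, workspace: Path, split: str):
        """Return the torch dataset/tensors for the given split."""

    @abstractmethod
    def get_dataloader(self, workspace: Path, split: str, batch_size: int) -> DataLoader:
        """Build a dataloader on top of load()."""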

Step 1 could be to remove all the existing fluff and make a clean version of MNIST, PTB, and a linear regression dataset (e.g. a LIBSVM dataset like abalone) to make sure the interface works for different data types.

  • Remove unused datasets
  • Make a clean interface for the dataset and implement
    • MNIST
    • ImageNet
    • PTB
    • WT2 & WT103

Rework Problem class

The problem class currently defines some pieces of the training loop.

This is a bit much, and some of it could be moved out of the Problem class and into the training loop.

The biggest problem with the Problem class is the hard-coded definition of the metrics, which we can get around if we introduce metrics separately as in #4. If we do this we can get rid of 90% of the code in this folder.

But I'd still want to keep the concept of a Problem around.
From the perspective of evaluating optimizers, the Problem is what we'd like to keep constant while trying many optimizers against it.
Having a consistent naming scheme and a set of problems would be nice for this,
and we can't really get there by taking the cross product of datasets and models, because some models only work with some datasets and don't make sense otherwise. Problem seems to be a nice abstraction for that.

And that abstraction might actually be necessary if we want to introduce more complicated "problems", say training a GAN,
where the loss is computed in a way that is technically different from just loss(m(x), y).
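For illustration only, the kind of abstraction this suggests, with hypothetical names and defaults rather than the library's actual API:

from dataclasses import dataclass, field
from typing import Any, List

import torch
from torch.utils.data import Dataset

@dataclass
class Problem:
    # Sketch: a Problem bundles everything that stays fixed while optimizers vary.
    model: torch.nn.Module
    dataset: Dataset
    lossfunc: torch.nn.Module
    metrics: List[Any] = field(default_factory=list)
    batch_size: int = 128

    def compute_loss(self, x, y):
        # Default supervised case: loss(m(x), y).
        return self.lossfunc(self.model(x), y)

# A GAN-style problem could subclass Problem and override compute_loss,
# without changing the training loop or the benchmarking code around it.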

Dataset and Dataloader - Batch-size and rounding-related parameters

What would be a reasonable way to handle the parameters related to the batch size and order of minibatches?

batch_size
micro_batch_size
shuffle
drop_last
sampler

Complete list of requirements (probably not feasible to hit all of them):

  • Make it easy to run full batch experiments
  • Make it easy to compare sampling IID and shuffling
  • Stay close to Pytorch behavior defaults
  • Reproducible behavior across micro batch sizes, so that their selection can be automated
  • Fit all the above with default parameters

The drop_last problem

Suppose we have N = number_of_samples, B = batch_size, and for simplicity batch_size == micro_batch_size.

What we want to avoid is taking a step on fewer than B samples.

I can see at least two ways to handle drop_last:

    1. Use only a subset of the data at each epoch, missing some samples every epoch (the PyTorch default).
    2. Truncate the dataset at the start and use all of it at each epoch.

Both miss N mod B < B samples, but (1) misses a different random subset every epoch while (2) misses none, at the cost of working on a slightly reduced dataset.
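A sketch of the two options with the standard PyTorch dataloader; the sizes are just for illustration, and option 2 assumes we are fine fixing once and for all which samples are kept:

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

N, B = 105, 10
dataset = TensorDataset(torch.randn(N, 3), torch.randint(0, 2, (N,)))

# Option 1: drop the incomplete batch; which 5 samples are missed changes
# every epoch when shuffling.
loader_1 = DataLoader(dataset, batch_size=B, shuffle=True, drop_last=True)

# Option 2: truncate once to a multiple of B; every epoch sees the same
# N - (N mod B) = 100 samples.
truncated = Subset(dataset, range(N - N % B))
loader_2 = DataLoader(truncated, batch_size=B, shuffle=True, drop_last=False)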

Additional difficulty 1: Interaction with multiprocessing

Using multiprocessing, drop_last is applied per dataloader on each worker node.
With W workers, we will miss up to N mod (BW) < BW samples.

Additional difficulty 2: Interaction with micro batch size /!\

The dataloader will iterate over blocks of micro_batch_size, not batch_size.
But if the dataloader has fewer than B samples left, we still shouldn't take a step, because it would be incomplete.
This cannot be implemented with the default PyTorch dataloader and drop_last, which only drops when the number of remaining samples is less than micro_batch_size.

With multiprocessing, we should stop when the local dataloader contains fewer than B/W samples.

But if we leave drop_last=False, then we don't actually need B mod micro_batch_size == 0,
because we can accommodate any batch size. Although we might have to randomly drop samples..?
(Maybe too much in the weeds.)
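To make this concrete, a rough sketch of the accumulation logic it implies: iterate in micro-batches, only take a step once a full batch worth of micro-batches has been collected, and drop an incomplete tail rather than stepping on it. The helper name and the divisibility assumption are mine, not the library's:

def accumulate_steps(micro_batches, batch_size, micro_batch_size):
    # Yield lists of micro-batches that together form one full batch.
    # Assumes batch_size % micro_batch_size == 0.
    per_step = batch_size // micro_batch_size
    buffer = []
    for micro_batch in micro_batches:
        buffer.append(micro_batch)
        if len(buffer) == per_step:
            yield buffer  # take one optimizer step on these micro-batches
            buffer = []
    # Anything left over is an incomplete batch: drop it rather than stepping.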

Additional difficulty 3: Consistency of the "problem" across seeds /!\

  • If we do 1.
    • and disable shuffling: we never get to optimize on a small subset of the dataset (the same samples are always dropped).
    • and don't disable shuffling: we don't run in full batch but in almost-full batch.
  • If we do 2.
    • and use a different seed for the dataset: different seeds give different datasets, which leads to a different problem every time.
    • and use the same seed for the dataset: we get the same dataset every time, but we have to make sure the validation dataloader is consistent with it.

Additional difficulty 4: Difference between training and evaluation /!\

We don't need drop_last on the evaluation datasets if we're only using them to iterate through the data for evaluation.

Proposed solution

?

Sanity check:

  • How does it work if we want to run in full batch?
    • Ideally we would just have to set B = N, instead of having to set a separate full_batch mode
      • How do we know what N is?
        • What if it changes as a function of the batch size?
          • It doesn't (?), in the
      • What if N is a prime number and there is no way to set a reasonable micro batch size?
