PyTorch Balanced Sampler

PyTorch implementations of BatchSampler that under- or oversample classes according to a chosen parameter alpha, in order to create a more balanced training distribution.

./resources/sample-distributions-2015-data.png

Usage

SamplerFactory

The factory class constructs a PyTorch BatchSampler that yields balanced samples from a training distribution.

from torch.utils.data import DataLoader, Dataset

from pytorch_balanced_sampler.sampler import SamplerFactory

# which sample indices belong to each of 4 classes
class_idxs = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12],
    [13, 14, 15, 16, 17, 18, 19, 20]
]

batch_sampler = SamplerFactory().get(
    class_idxs=class_idxs,
    batch_size=32,
    n_batches=250,
    alpha=0.5,
    kind='fixed'
)

dataset = Dataset( ... )
data_loader = DataLoader(dataset, batch_sampler=batch_sampler)

for data, target in data_loader:
    # nice balanced batches!
    ...
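
The class_idxs argument above is a list with one entry per class, each holding the dataset indices that belong to that class. If you only have a flat list of labels, a minimal sketch like the following can build it. The helper name build_class_idxs and the example labels are hypothetical and are not part of pytorch_balanced_sampler.

```python
from collections import defaultdict


def build_class_idxs(labels):
    """Group dataset indices by class label.

    Returns one list of indices per class, ordered by sorted label value.
    (Hypothetical helper, not part of pytorch_balanced_sampler.)
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    return [by_class[label] for label in sorted(by_class)]


# Hypothetical labels for an 8-sample dataset with 3 classes
labels = [0, 2, 1, 0, 0, 1, 2, 0]
class_idxs = build_class_idxs(labels)
print(class_idxs)  # [[0, 3, 4, 7], [2, 5], [1, 6]]
```

The resulting list can then be passed as class_idxs to SamplerFactory().get(...).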

Class Balancing

Based on the choice of an alpha parameter in [0, 1], the sampler adjusts the sample distribution to lie between the true distribution (alpha = 0) and a uniform distribution (alpha = 1).

Overrepresented classes will be undersampled, and underrepresented classes oversampled. Here's an example from an imbalanced data distribution I was working with a while ago:

./resources/sample-distributions-2015-data.png
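
To make the role of alpha concrete, here is a minimal sketch of how target class weights could be computed. It assumes a simple linear interpolation between the true class frequencies and a uniform distribution, which matches the endpoints described above but is an assumption about everything in between. The function name balanced_class_weights is hypothetical.

```python
def balanced_class_weights(class_idxs, alpha):
    """Blend the empirical class distribution with a uniform one.

    alpha = 0 returns the empirical distribution, alpha = 1 the uniform one.
    Linear interpolation is an assumption of this sketch.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    sizes = [len(idxs) for idxs in class_idxs]
    n_samples = sum(sizes)
    n_classes = len(sizes)
    true_weights = [size / n_samples for size in sizes]
    uniform_weight = 1.0 / n_classes
    return [(1.0 - alpha) * w + alpha * uniform_weight for w in true_weights]


# The four classes from the usage example (sizes 5, 5, 2 and 8)
class_idxs = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12],
    [13, 14, 15, 16, 17, 18, 19, 20],
]
for alpha in (0.0, 0.5, 1.0):
    weights = balanced_class_weights(class_idxs, alpha)
    print(alpha, [round(w, 3) for w in weights])
```

Under this assumption, alpha = 0.5 raises the smallest class from 10% to 17.5% of the samples and lowers the largest from 40% to 32.5%.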

Fixed Batch Distributions

If you select kind='fixed', each batch generated will contain the same proportion of each class. E.g., if we have 5 classes, we might receive batches like:

Batch: 0
Classes: [1, 0, 0, 0, 2, 4, 0, 2, 0, 0, 3, 2, 1, 0, 2, 0, 0, 3, 0, 0, 4, 4, 0, 2, 1, 3, 3, 1, 2, 0, 0, 4]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}

Batch: 1
Classes: [4, 1, 1, 2, 0, 0, 0, 4, 2, 4, 0, 3, 1, 3, 0, 0, 3, 2, 0, 2, 4, 2, 0, 0, 2, 3, 0, 1, 0, 0, 0, 0]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}

Batch: 2
Classes: [0, 4, 0, 0, 0, 3, 3, 2, 0, 4, 2, 3, 0, 3, 2, 0, 0, 1, 2, 2, 0, 1, 0, 0, 4, 0, 2, 1, 1, 4, 0, 0]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}

Note that the class counts are the same for each batch.
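
One way to obtain identical class counts in every batch is to convert the target class weights into integer counts once, then reuse them for each batch. The sketch below uses largest-remainder rounding so that the counts always sum to the batch size. The rounding scheme, the function name class_counts_per_batch, and the example weights are assumptions for illustration, not necessarily what the library does.

```python
import math


def class_counts_per_batch(weights, batch_size):
    """Turn class weights into integer per-batch counts that sum to batch_size.

    Uses largest-remainder rounding, which is an assumption of this sketch.
    """
    exact = [w * batch_size for w in weights]
    counts = [math.floor(x) for x in exact]
    leftover = batch_size - sum(counts)
    # Hand the remaining slots to the classes with the largest fractional parts
    order = sorted(
        range(len(weights)),
        key=lambda i: exact[i] - counts[i],
        reverse=True,
    )
    for i in order[:leftover]:
        counts[i] += 1
    return counts


# Hypothetical target weights for 4 classes
weights = [0.25, 0.25, 0.175, 0.325]
counts = class_counts_per_batch(weights, batch_size=32)
print(counts, sum(counts))  # [8, 8, 6, 10] 32
```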

Random Batch Distributions

If you don't want to fix the number of samples from each class in each batch, you can select kind='random', which samples with replacement. The samples are weighted so as to produce the target class distribution on average.
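
As a rough illustration of weighted sampling with replacement, the sketch below spreads each class's target weight evenly over its samples and then draws a batch using Python's standard random.choices. The helper names and the uniform target weights are hypothetical, and the library's own implementation may differ.

```python
import random
from collections import Counter


def per_sample_weights(class_idxs, class_weights):
    """Spread each class's target weight evenly over its samples."""
    indices, weights = [], []
    for idxs, class_weight in zip(class_idxs, class_weights):
        for idx in idxs:
            indices.append(idx)
            weights.append(class_weight / len(idxs))
    return indices, weights


def random_batch(indices, weights, batch_size, rng):
    """Draw one batch of sample indices with replacement."""
    return rng.choices(indices, weights=weights, k=batch_size)


class_idxs = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12],
    [13, 14, 15, 16, 17, 18, 19, 20],
]
class_weights = [0.25, 0.25, 0.25, 0.25]  # hypothetical target: uniform over classes

rng = random.Random(0)
indices, weights = per_sample_weights(class_idxs, class_weights)
batch = random_batch(indices, weights, batch_size=32, rng=rng)

# Count how many samples of each class landed in this batch
class_of = {idx: c for c, idxs in enumerate(class_idxs) for idx in idxs}
print(Counter(class_of[idx] for idx in batch))
```

Per-batch class counts vary from draw to draw, but over many batches they approach the target weights.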

Authors

pytorch_balanced_sampler was written by Karl Hornlund.
