
ssljax's Issues

SWAV Prototype

Implement mm(z, C) as below

for x in loader: # load a batch x with B samples
    x_t = t(x) # t is a random augmentation
    x_s = s(x) # s is another random augmentation
    z = model(cat(x_t, x_s)) # embeddings: 2BxD
    scores = mm(z, C) # prototype scores: 2BxK
    scores_t = scores[:B]
    scores_s = scores[B:]
    # compute assignments
    with torch.no_grad():
        q_t = sinkhorn(scores_t)
        q_s = sinkhorn(scores_s)
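
A minimal JAX sketch of the mm(z, C) step, assuming z holds the concatenated view embeddings and the prototypes C are stored as a D×K matrix; SwAV keeps both the embeddings and the prototype columns L2-normalized, which is folded in here. Names are illustrative, not fixed ssljax API.

import jax.numpy as jnp

def prototype_scores(z, C):
    """Prototype scores z @ C with both operands L2-normalized.

    z: (2B, D) concatenated embeddings of the two augmented views.
    C: (D, K) prototype matrix.
    Returns a (2B, K) score matrix, the input to sinkhorn().
    """
    z = z / jnp.linalg.norm(z, axis=-1, keepdims=True)
    C = C / jnp.linalg.norm(C, axis=0, keepdims=True)
    return z @ C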

SWAV Sinkhorn-Knopp Algorithm

Implement Sinkhorn-Knopp as in https://arxiv.org/abs/2006.09882

# Sinkhorn-Knopp
def sinkhorn(scores, eps=0.05, niters=3):
    Q = exp(scores / eps).T
    Q /= sum(Q)
    K, B = Q.shape
    u, r, c = zeros(K), ones(K) / K, ones(B) / B
    for _ in range(niters):
        u = sum(Q, dim=1)
        Q *= (r / u).unsqueeze(1)
        Q *= (c / sum(Q, dim=0)).unsqueeze(0)
    return (Q / sum(Q, dim=0, keepdim=True)).T
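
A hedged JAX translation of the pseudocode above; the function and argument names are ours, and the iteration count is static so the loop unrolls cleanly inside a jitted step.

import jax.numpy as jnp

def sinkhorn(scores, eps=0.05, niters=3):
    """Sinkhorn-Knopp normalization of prototype scores (arXiv:2006.09882).

    scores: (B, K) prototype scores for one view.
    Returns soft assignments of shape (B, K) whose rows sum to one.
    """
    Q = jnp.exp(scores / eps).T              # (K, B)
    Q = Q / jnp.sum(Q)
    K, B = Q.shape
    r = jnp.ones(K) / K                      # target row marginals
    c = jnp.ones(B) / B                      # target column marginals
    for _ in range(niters):
        u = jnp.sum(Q, axis=1)
        Q = Q * (r / u)[:, None]             # match row marginals
        Q = Q * (c / jnp.sum(Q, axis=0))[None, :]  # match column marginals
    return (Q / jnp.sum(Q, axis=0, keepdims=True)).T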

DINO multicrop preprocessing

DINO uses a multi-crop scheme in which the student processes "local views" (96x96) and their representations are compared, via a cross-entropy loss, to those of "global views" (224x224) processed by the teacher.

In the following, we detail how we adapt the problem in Eq. (2) to self-supervised learning. First, we construct different distorted views, or crops, of an image with the multi-crop strategy [10]. More precisely, from a given image, we generate a set V of different views. This set contains two global views, x^g_1 and x^g_2, and several local views of smaller resolution. All crops are passed through the student while only the global views are passed through the teacher, therefore encouraging "local-to-global" correspondences. We minimize the loss:

[omitted cross entropy]

This loss is general and can be used on any number of views, even only 2. However, we follow the standard setting for multi-crop by using 2 global views at resolution 224² covering a large (for example greater than 50%) area of the original image, and several local views of resolution 96² covering only small areas (for example less than 50%) of the original image. We refer to this setting as the basic parametrization of DINO, unless mentioned otherwise.

This is implemented in PyTorch as MultiCropWrapper:
https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L183
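
Outside Torch, the same idea is to group crops that share a resolution so each backbone call sees a single input shape, then run the head once on the concatenated features. A hedged sketch, with backbone_apply, head_apply, and the params layout as assumptions:

import jax.numpy as jnp

def multi_crop_forward(backbone_apply, head_apply, params, crops):
    """Run a list of crops through the backbone, batching crops of equal resolution.

    crops: list of (n_i, H_i, W_i, C) arrays; global and local crops differ in H, W.
    Returns head outputs concatenated over all crops.
    """
    features = []
    start = 0
    while start < len(crops):
        # find the run of crops sharing the current spatial shape
        end = start
        while end < len(crops) and crops[end].shape[1:] == crops[start].shape[1:]:
            end += 1
        batch = jnp.concatenate(crops[start:end], axis=0)
        features.append(backbone_apply(params["backbone"], batch))
        start = end
    return head_apply(params["head"], jnp.concatenate(features, axis=0))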

Conv Utils

Utilities for convolution blocks needed for ResNet and ResNeXt. These should live in a separate file.

DINO documentation

Write a technical overview of DINO to include with the config in documentation.

Why Model's init should take its arguments and not a config object

We had some discussion earlier about whether the __init__ function of our classes should take a single params object (or config, whatever we end up calling it), or whether each argument should be listed out explicitly. I wanted to formally create an issue for this topic so we can discuss it, and to explain my position in full.

To give an example, consider an encoder-decoder seq2seq model EncoderDecoder that has the architecture:

Embedder -> Encoder -> Decoder -> Linear Layer -> Dropout

Here are hypothetical __init__ functions for each idea.

If we take a params object (the HuggingFace approach):

def __init__(
        self,
        params: Params
):
    self.embedder = Embedder.fromParams(params.pop("embedder"))
    self.encoder = SeqEncoder.fromParams(params.pop("encoder"))
    self.decoder = SeqDecoder.fromParams(params.pop("decoder"))
    self.output_layer = Linear.fromParams(params.pop("output_layer"))
    self.dropout = jax.experimental.stax.Dropout(params.pop("dropout", 0))

    tied_source_embedder_key = params.pop("tied_source_embedder_key", None)
    if tied_source_embedder_key:
        self._tie_source_and_decoder_embedders(tied_source_embedder_key)

Here is how it would look if each argument of __init__ is specified (the AllenNLP approach):

def __init__(
        self,
        embedder: Embedder,
        encoder: SeqEncoder,
        decoder: SeqDecoder,
        output_layer: Linear,
        dropout: float = 0.0,
        tied_source_embedder_key: Optional[str] = None,
):
    self.embedder = embedder
    self.encoder = encoder
    self.decoder = decoder
    self.output_layer = output_layer
    self.dropout = jax.experimental.stax.Dropout(dropout)

    if tied_source_embedder_key:
        self._tie_source_and_decoder_embedders(tied_source_embedder_key)

The difference here is that the fromParams calls are removed; the FromParams base class handles them instead. FromParams determines which constructor to call from the type annotations, initializes those arguments, and then passes them to the class being constructed, in this case the EncoderDecoder model. It also handles argument and type checking automatically. This behavior is only possible when the arguments are spelled out in the __init__ signature; without it, we would need to implement our own checks inside each class's __init__.
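
To make that mechanism concrete, here is a toy sketch of annotation-driven construction. It is not the AllenNLP FromParams implementation, just an illustration; the function name and the dict-based config format are assumptions.

import inspect

def build_from_params(cls, params: dict):
    """Toy sketch: instantiate cls by matching config keys to __init__ type annotations."""
    kwargs = {}
    signature = inspect.signature(cls.__init__)
    for name, parameter in signature.parameters.items():
        if name == "self" or name not in params:
            continue
        annotation = parameter.annotation
        value = params[name]
        if inspect.isclass(annotation) and isinstance(value, dict):
            # nested config: recursively build the sub-module, then pass the instance
            kwargs[name] = build_from_params(annotation, value)
        else:
            kwargs[name] = value
    return cls(**kwargs)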

I also believe that the params-object approach is much harder for end users to understand without prior experience with the library.

I do want others' opinions on this and which direction we should go.

VICReg config

Base VICReg experiment config

Implementation details for pretraining with VICReg on the 1000-classes ImageNet dataset without labels are as follows. Coefficients λ and μ are 25 and ν is 1 in Eq. (6), and ε is 0.0001 in Eq. (1). We give more details on how we choose the coefficients of the loss function in Appendix C.3. The encoder network fθ is a standard ResNet-50 backbone He et al. (2016) with 2048 output units. The expander hφ is composed of two fully-connected layers with batch normalization (BN) Ioffe & Szegedy (2015) and ReLU, and a third linear layer. The sizes of all 3 layers were set to 8192. As with Barlow Twins, performance improves when the size of the expander layers is larger than the dimension of the representation. The impact of the expander dimension on performance is studied in Appendix C.5. The training protocol follows those of BYOL and Barlow Twins: LARS optimizer You et al. (2017); Goyal et al. (2017) run for 1000 epochs with a weight decay of 10⁻⁶ and a learning rate lr = batch_size/256 × base_lr, where batch_size is set to 2048 by default and base_lr is a base learning rate set to 0.2. The learning rate follows a cosine decay schedule Loshchilov & Hutter (2017), starting from 0 with 10 warmup epochs and with final value of 0.002.
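
For reference while writing the config, here is a hedged JAX sketch of the loss with the quoted coefficients (λ = μ = 25, ν = 1, ε = 0.0001). It reflects our reading of the equations in the paper and has not been checked against the official implementation; all names are illustrative.

import jax.numpy as jnp

def vicreg_loss(z_a, z_b, lam=25.0, mu=25.0, nu=1.0, eps=1e-4):
    """VICReg loss: weighted sum of invariance, variance, and covariance terms.

    z_a, z_b: (N, D) expander outputs for the two augmented views.
    """
    n, d = z_a.shape

    # invariance: mean over the batch of the squared distance between paired embeddings
    invariance = jnp.mean(jnp.sum((z_a - z_b) ** 2, axis=-1))

    def variance(z):
        # hinge on the per-dimension standard deviation, target value 1
        std = jnp.sqrt(jnp.var(z, axis=0) + eps)
        return jnp.mean(jnp.maximum(0.0, 1.0 - std))

    def covariance(z):
        # sum of squared off-diagonal covariance entries, scaled by 1/d
        z = z - jnp.mean(z, axis=0, keepdims=True)
        cov = (z.T @ z) / (n - 1)
        off_diag = cov - jnp.diag(jnp.diag(cov))
        return jnp.sum(off_diag ** 2) / d

    return (lam * invariance
            + mu * (variance(z_a) + variance(z_b))
            + nu * (covariance(z_a) + covariance(z_b)))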

Implement meter

Task.meter takes logits, targets, weights and returns metrics
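
A minimal sketch of that interface, using top-1 accuracy as a placeholder metric; the signature and the returned dict are assumptions, not settled API.

import jax.numpy as jnp

def meter(logits, targets, weights=None):
    """Return a dict of metrics from logits, integer targets, and optional per-example weights."""
    predictions = jnp.argmax(logits, axis=-1)
    correct = (predictions == targets).astype(jnp.float32)
    if weights is None:
        weights = jnp.ones_like(correct)
    accuracy = jnp.sum(correct * weights) / jnp.sum(weights)
    return {"accuracy": accuracy}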

DINO centering and sharpening

From https://arxiv.org/abs/2104.14294,

While our framework can be stabilized with multiple normalizations [10], it can also work with only a centering and sharpening of the momentum teacher outputs to avoid model collapse. As shown experimentally in Section 5.3, centering prevents one dimension to dominate but encourages collapse to the uniform distribution, while the sharpening has the opposite effect. Applying both operations balances their effects which is sufficient to avoid collapse in presence of a momentum teacher. Choosing this method to avoid collapse trades stability for less dependence over the batch: the centering operation only depends on first-order batch statistics and can be interpreted as adding a bias term c to the teacher: g_t(x) ← g_t(x) + c. The center c is updated with an exponential moving average, which allows the approach to work well across different batch sizes as shown in Section 5.5.

Implemented in PyTorch here: https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L363

def update_center(self, teacher_output):
    """
    Update center used for teacher output.
    """
    batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
    dist.all_reduce(batch_center)
    batch_center = batch_center / (len(teacher_output) * dist.get_world_size())

    # ema update
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
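
A single-host JAX sketch of the same EMA update (the dist.all_reduce would become a jax.lax.pmean inside a pmapped step in the multi-device case). This is an illustration, not ssljax code; the 0.9 momentum default mirrors the DINO repository.

import jax.numpy as jnp

def update_center(center, teacher_output, center_momentum=0.9):
    """EMA update of the centering term applied to teacher outputs.

    center: (1, K) running center.
    teacher_output: (B, K) teacher outputs for the current batch.
    """
    batch_center = jnp.mean(teacher_output, axis=0, keepdims=True)
    return center * center_momentum + batch_center * (1.0 - center_momentum)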

EMA tau scheduler

From the BYOL paper:

For the target network, the exponential moving average parameter τ starts from τ_base = 0.996 and is increased to one during training. Specifically, we set τ = 1 − (1 − τ_base) · (cos(πk/K) + 1)/2, with k the current training step and K the maximum number of training steps.
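
A direct sketch of that schedule (the function name is ours):

import math

def tau_schedule(step, max_steps, tau_base=0.996):
    """Cosine schedule for the EMA parameter tau, increasing from tau_base to 1."""
    return 1.0 - (1.0 - tau_base) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0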

Implement trainer

Add trainer class to manage model training and feature extraction.
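
One possible shape for the class, purely as a strawman; the constructor arguments and method names are assumptions to anchor the discussion, not decided API.

class Trainer:
    """Skeleton: owns the training loop and exposes feature extraction."""

    def __init__(self, model, optimizer, dataloader):
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader

    def train(self, epochs):
        for _ in range(epochs):
            for batch in self.dataloader:
                self.step(batch)

    def step(self, batch):
        # one optimization step; implemented per task
        raise NotImplementedError

    def extract_features(self, batch):
        # run the trained model to produce representations for downstream use
        raise NotImplementedError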

SWAV Augmentations/Pipeline

From https://arxiv.org/abs/2006.09882,

We obtain two different views from an image by performing crops of random sizes and aspect ratios.
Specifically we use the RandomResizedCrop method from torchvision.transforms module
of PyTorch with the following scaling parameters: s=(0.14, 1). Note that we sample crops in
a narrower range of scale compared to the default RandomResizedCrop parameters. Then, we
resize both full resolution views to 224 × 224 pixels, unless specified otherwise (we use 160 × 160
resolutions in some of our experiments). Besides, we obtain V additional views by cropping small
parts in the image. To do so, we use the following RandomResizedCrop parameters: s=(0.05,
0.14). We resize the resulting crops to 96 × 96 resolution. Note that we always deal with resolutions
that are divisible by 32 to avoid roundings in the ResNet-50 pooling layers. Finally, we apply random
horizontal flips, color distortion and Gaussian blur to each resulting crop, exactly following the
SimCLR implementation [10]. An illustration of our multi-crop augmentation strategy can be
viewed in Fig. 5
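
The crop parameters quoted above translate roughly as follows; the flips, color distortion, and blur are left as a comment because their exact parameters come from the SimCLR implementation and are not restated here.

from torchvision import transforms

# Two global views: random crops with the narrower scale range, resized to 224x224.
global_crop = transforms.RandomResizedCrop(224, scale=(0.14, 1.0))

# V additional local views: small crops resized to 96x96.
local_crop = transforms.RandomResizedCrop(96, scale=(0.05, 0.14))

# Each resulting crop is then followed by random horizontal flips, color distortion,
# and Gaussian blur, exactly following the SimCLR implementation.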

Release 1.0.0

PyTorch libraries:

  1. VISSL https://github.com/facebookresearch/vissl/tree/master/vissl
  2. Pytorch Lightning https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html

API

  • ssljax.models.projection
    • MLP + Softmax
  • ssljax.models.backbone
    • ViT
    • Sparse ViT
    • Vision Longformer
    • Swin Transformer
    • ResNeXt
    • #2
  • ssljax.augment [just wrap vissl.data?]
  • ssljax.optimizers
  • ssljax.trainers
    • Student-Teacher Schemes
      • DINO
      • BYOL

Repository

Models

Examples

  • Example notebooks for some basic models

Generalising Augmentation Distributions

Right now, the probability of applying an augmentation is a property of the Augmentation itself, i.e.:

class Augmentation:
    """
    An augmentation applies a function to data.

    Args:
        prob (float): Probability that augmentation will be executed.
    """

    def __init__(self, prob=1.0):
        assert isinstance(prob, float), "prob must be of type float"
        self.prob = prob

This variable defines the probability that the augmentation is applied. This allows for pipelines like:

RandomFlip(0.5) → GaussianBlur(0.5) → ....

However, this framework is not flexible enough to allow for distributions over multiple augmentations:

sample one ( RandomFlip(0.5), GaussianBlur(0.5) ) → ....

This issue proposes refactoring Augmentation to be a deterministic transform, with its probability becoming a property of an AugmentationDistribution class:

class AugmentationDistribution:
    """
    A categorical distribution of augmentations to apply to input

    Args:
        probs (List[float]): List of probabilities
        augs (List[Augmentation]): List of augmentations to be sampled
    """
    def __init__(self, probs, augs):
        self.probs = jax.numpy.array(probs)
        self.augs = augs

This will allow us to express the distribution over multiple augmentations mentioned above via:

AugmentationDistribution([0.5,0.5],[Flip,GaussianBlur])

A proof of concept for the above refactor is shown below

import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

class Augmentation:
    def __init__(self, add):
        self.add = add

    def __call__(self, x):
        return self.add+x

class AugmentationDistribution:
    def __init__(self, probs, augs):
        self.probs = jax.numpy.array(probs)
        self.augs = augs
    def apply(self, rng, x):
        key, subkey = jax.random.split(rng)
        sampledIndex = jax.random.choice(subkey,len(self.augs),p=self.probs)
        x = jax.lax.switch(sampledIndex,self.augs,x)
        return x, key

identity = Augmentation(0.0)
addOne = Augmentation(1.0)
addHalf = Augmentation(0.5)

pipeline = [(AugmentationDistribution([0.2,0.8],[addOne,addHalf])),
            (AugmentationDistribution([0.2,0.8],[addOne,identity]))]

@jax.jit
def applyAugPipeline(rng, x):
    for augDist in pipeline:
        x,rng = augDist.apply(rng,x)
    return x

print(applyAugPipeline(key,0))

Implement Augmentations

We define two classes, Augmentation and Pipeline.

class Augmentation:
    def __init__(self, prob): ...
    def __call__(self, x): ...

class Pipeline(Augmentation):
    def __init__(self, augmentations): ...
    def __call__(self, x):
        # this just composes augmentations
        ...
    def sample(self): ...

Use cases (see the sketch after this list):

  1. compose a list of augmentations in order
  2. sample from a set of augmentations (with and without replacement)
  3. a mix of ordering and sampling (e.g. execute a fixed first augmentation, then sample the middle augmentations, then execute a fixed final augmentation)
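
A hedged sketch covering the three use cases, assuming the Augmentation interface above; the composition semantics and method signatures are guesses, not settled API.

import random

class Pipeline(Augmentation):
    """Sketch: compose augmentations in order; sample() draws a random subset."""

    def __init__(self, augmentations, prob=1.0):
        super().__init__(prob)
        self.augmentations = augmentations

    def __call__(self, x):
        # use case 1: apply every augmentation in order
        for augmentation in self.augmentations:
            x = augmentation(x)
        return x

    def sample(self, k, replace=False):
        # use case 2: draw k augmentations, with or without replacement
        if replace:
            return [random.choice(self.augmentations) for _ in range(k)]
        return random.sample(self.augmentations, k)

# use case 3: fixed first and last steps around a sampled middle
def mixed_pipeline(first, middle, last, k=2):
    return Pipeline([first, *Pipeline(middle).sample(k), last])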

Augmentations in the model forward pass or in the dataloader?

There isn't consensus in reference implementations about where to augment.

Lucidrains augments in (his equivalent to) the BaseSSL model forward pass
https://github.com/lucidrains/byol-pytorch/blob/master/byol_pytorch/byol_pytorch.py

DeepMind augments in the forward pass (in the loss function!)

Another natural thing to do is to augment in the dataloader (augmentation happens right before passing data to Body).

We should benchmark different approaches

TODO 1.0.0

  • Metrics @gabeorlanski
  • Tensorboard #55
  • Logging
  • DynamicScaling @ryanccarelli #52
  • Gradient Accumulation @AkashGanesan #56
  • Checkpointing @ryanccarelli #55
  • Improved config (.ipynb) @AkashGanesan
  • Modules from pretrained
  • BYOL @ryanccarelli
    • Implementation
      • Separate preprocess and postprocess pipelines
      • Fix RandomCrop augmentation
    • Benchmarks
  • DINO
    • Implementation
      • Centering Layer
      def update_center(self, teacher_output):
          """
          Update center used for teacher output.
          """
          batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
          dist.all_reduce(batch_center)
          batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
      
          # ema update
          self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
      • DINO/SWAV MLP
        ("The projection head consists of a 3-layer multi-layer perceptron (MLP) with hidden
        dimension 2048 followed by `2 normalization and a weight normalized fully connected
        layer [61] with K dimensions, which is similar to the design from SwAV")
        https://github.com/facebookresearch/dino
      class DINOHead(nn.Module):
          def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
              super().__init__()
              nlayers = max(nlayers, 1)
              if nlayers == 1:
                  self.mlp = nn.Linear(in_dim, bottleneck_dim)
              else:
                  layers = [nn.Linear(in_dim, hidden_dim)]
                  if use_bn:
                      layers.append(nn.BatchNorm1d(hidden_dim))
                  layers.append(nn.GELU())
                  for _ in range(nlayers - 2):
                      layers.append(nn.Linear(hidden_dim, hidden_dim))
                      if use_bn:
                          layers.append(nn.BatchNorm1d(hidden_dim))
                      layers.append(nn.GELU())
                  layers.append(nn.Linear(hidden_dim, bottleneck_dim))
                  self.mlp = nn.Sequential(*layers)
              self.apply(self._init_weights)
              self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
              self.last_layer.weight_g.data.fill_(1)
              if norm_last_layer:
                  self.last_layer.weight_g.requires_grad = False
      
          def _init_weights(self, m):
              if isinstance(m, nn.Linear):
                  trunc_normal_(m.weight, std=.02)
                  if isinstance(m, nn.Linear) and m.bias is not None:
                      nn.init.constant_(m.bias, 0)
      
          def forward(self, x):
              x = self.mlp(x)
              x = nn.functional.normalize(x, dim=-1, p=2)
              x = self.last_layer(x)
              return x
    • Benchmarks
  • SWAV
    • Implementation
    • Benchmarks
  • Documentation
    • core
    • introduction/overview
    • model descriptions
  • Tests
  • Load pretrained model for individual component (only ViT for example)

Distributed sampler

Which approach will we use to support a distributed sampler?

From DINO (torch.data)

    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    dataset = datasets.ImageFolder(args.data_path, transform=transform)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

From BYOL (tfds) https://github.com/deepmind/deepmind-research/blob/2c7c401024c42c4fb1aa20a8b0471d2e6b480906/byol/utils/dataset.py#L56

class PreprocessMode(enum.Enum):
  """Preprocessing modes for the dataset."""
  PRETRAIN = 1  # Generates two augmented views (random crop + augmentations).
  LINEAR_TRAIN = 2  # Generates a single random crop.
  EVAL = 3  # Generates a single center crop.


def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
  """Normalize the image using ImageNet statistics."""
  mean_rgb = (0.485, 0.456, 0.406)
  stddev_rgb = (0.229, 0.224, 0.225)
  normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3))
  normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3))
  return normed_images


def load(split: Split,
         *,
         preprocess_mode: PreprocessMode,
         batch_dims: Sequence[int],
         transpose: bool = False,
         allow_caching: bool = False) -> Generator[Batch, None, None]:
  """Loads the given split of the dataset."""
  start, end = _shard(split, jax.host_id(), jax.host_count())

  total_batch_size = np.prod(batch_dims)

  tfds_split = tfds.core.ReadInstruction(
      _to_tfds_split(split), from_=start, to=end, unit='abs')
  ds = tfds.load(
      'imagenet2012:5.*.*',
      split=tfds_split,
      decoders={'image': tfds.decode.SkipDecoding()})

  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1

  if preprocess_mode is not PreprocessMode.EVAL:
    options.experimental_deterministic = False
    if jax.host_count() > 1 and allow_caching:
      # Only cache if we are reading a subset of the dataset.
      ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)

  else:
    if split.num_examples % total_batch_size != 0:
      raise ValueError(f'Test/valid must be divisible by {total_batch_size}')

  ds = ds.with_options(options)

  def preprocess_pretrain(example):
    view1 = _preprocess_image(example['image'], mode=preprocess_mode)
    view2 = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'view1': view1, 'view2': view2, 'labels': label}

  def preprocess_linear_train(example):
    image = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  def preprocess_eval(example):
    image = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  if preprocess_mode is PreprocessMode.PRETRAIN:
    ds = ds.map(
        preprocess_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  elif preprocess_mode is PreprocessMode.LINEAR_TRAIN:
    ds = ds.map(
        preprocess_linear_train,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  else:
    ds = ds.map(
        preprocess_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def transpose_fn(batch):
    # We use the double-transpose-trick to improve performance for TPUs. Note
    # that this (typically) requires a matching HWCN->NHWC transpose in your
    # model code. The compiler cannot make this optimization for us since our
    # data pipeline and model are compiled separately.
    batch = dict(**batch)
    if preprocess_mode is PreprocessMode.PRETRAIN:
      batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
      batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
    else:
      batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
    return batch

  for i, batch_size in enumerate(reversed(batch_dims)):
    ds = ds.batch(batch_size)
    if i == 0 and transpose:
      ds = ds.map(transpose_fn)  # NHWC -> HWCN

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  yield from tfds.as_numpy(ds)

Separate module from model definition, instantiate modules in SSLModel

Currently we instantiate modules in Branch by iterating over each branch's config (self.config.model.branches.items()). This change instead instantiates modules in SSLModel and passes the built modules to the branches.

Advantages:

  1. Parameter sharing: We can easily tie parameters by calling the same module twice in different branches.
  2. Modularity: Branches are now just wrappers around composition.

Here is a scheme with parameter sharing:

modules:
  body1:
    name: ResNet
    params: {}
  head1:
    name: MLP
    params: {}

model:
  name: SSLModel
  branches:
    0:
      body: body1
      head: head1
    1:
      body: body1
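
A rough sketch of how SSLModel could consume this config so both branches receive the same body instance; the function name, the registry mapping class names to constructors, and the dict-based config access are all illustrative assumptions.

def build_branches(config, registry):
    """Instantiate each named module once, then hand the same objects to the branches."""
    modules = {
        name: registry[spec["name"]](**spec.get("params", {}))
        for name, spec in config["modules"].items()
    }
    # Branch 0 and branch 1 both reference "body1", so they get the same
    # module instance and therefore share parameters.
    return {
        branch_id: {role: modules[module_name] for role, module_name in spec.items()}
        for branch_id, spec in config["model"]["branches"].items()
    }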
