ryanccarelli / ssljax
self-supervised learning in jax
License: GNU General Public License v2.0
Implement mm(z, C) as below
for x in loader: # load a batch x with B samples
    x_t = t(x) # t is a random augmentation
    x_s = s(x) # s is another random augmentation
    z = model(cat(x_t, x_s)) # embeddings: 2BxD
    scores = mm(z, C) # prototype scores: 2BxK
    scores_t = scores[:B]
    scores_s = scores[B:]
    # compute assignments
    with torch.no_grad():
        q_t = sinkhorn(scores_t)
        q_s = sinkhorn(scores_s)
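A minimal JAX sketch of the `mm(z, C)` step, assuming `C` is a D x K prototype matrix (that shape is inferred from the 2BxK scores in the pseudocode, not stated in the issue). The name `prototype_scores` and the random stand-in inputs are hypothetical.

```python
import jax
import jax.numpy as jnp


def prototype_scores(z, C):
    """Dot-product scores between embeddings z (N x D) and prototypes C (D x K)."""
    return z @ C


key_z, key_c = jax.random.split(jax.random.PRNGKey(0))
B, D, K = 4, 16, 8
z = jax.random.normal(key_z, (2 * B, D))  # stand-in for model(cat(x_t, x_s))
C = jax.random.normal(key_c, (D, K))      # stand-in prototype matrix
scores = prototype_scores(z, C)           # prototype scores: 2B x K
scores_t, scores_s = scores[:B], scores[B:]
```

The split into `scores_t` and `scores_s` relies on the two views being concatenated along the batch axis in that order.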
Implement Sinkhorn-Knopp as in https://arxiv.org/abs/2006.09882
# Sinkhorn-Knopp
def sinkhorn(scores, eps=0.05, niters=3):
    Q = exp(scores / eps).T
    Q /= sum(Q)
    K, B = Q.shape
    u, r, c = zeros(K), ones(K) / K, ones(B) / B
    for _ in range(niters):
        u = sum(Q, dim=1)
        Q *= (r / u).unsqueeze(1)
        Q *= (c / sum(Q, dim=0)).unsqueeze(0)
    return (Q / sum(Q, dim=0, keepdim=True)).T
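The pseudocode above could be translated to JAX roughly as follows. This is a sketch rather than a tuned implementation: it keeps the same `eps` and `niters` defaults, replaces in-place updates with functional ones, and does nothing to guard against overflow in `exp` for large scores. The demo inputs are made up.

```python
import jax
import jax.numpy as jnp


def sinkhorn(scores, eps=0.05, niters=3):
    """Sinkhorn-Knopp normalization of prototype scores (B x K) into soft assignments."""
    Q = jnp.exp(scores / eps).T  # K x B
    Q = Q / jnp.sum(Q)
    K, B = Q.shape
    r = jnp.ones(K) / K  # target marginal over prototypes
    c = jnp.ones(B) / B  # target marginal over samples
    for _ in range(niters):
        u = jnp.sum(Q, axis=1)
        Q = Q * (r / u)[:, None]                   # rescale rows (prototypes)
        Q = Q * (c / jnp.sum(Q, axis=0))[None, :]  # rescale columns (samples)
    # Normalize each sample's assignment to sum to 1, then return as B x K.
    return (Q / jnp.sum(Q, axis=0, keepdims=True)).T


scores_demo = jax.random.uniform(jax.random.PRNGKey(0), (8, 4))
q = sinkhorn(scores_demo)
```

To mirror the `torch.no_grad()` context from the training loop, the result can be wrapped in `jax.lax.stop_gradient` before it is used as a target.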
DINO uses a multi-crop scheme: the student processes both 96x96 "local views" and 224x224 "global views", and its outputs are compared (by categorical cross-entropy) with the teacher's outputs on the global views.
In the following, we detail how we adapt the problem in Eq. (2) to self-supervised learning. First, we construct different distorted views, or crops, of an image with multi-crop strategy [10]. More precisely, from a given image, we generate a set V of different views. This set contains two global views, x^g_1 and x^g_2, and several local views of smaller resolution. All crops are passed through the student while only the global views are passed through the teacher, therefore encouraging “local-to-global” correspondences. We minimize the loss:
[omitted cross entropy]
This loss is general and can be used on any number of views, even only 2. However, we follow the standard setting for multi-crop by using 2 global views at resolution 224^2 covering a large (for example greater than 50%) area of the original image, and several local views of resolution 96^2 covering only small areas (for example less than 50%) of the original image. We refer to this setting as the basic parametrization of DINO, unless mentioned otherwise.
This is implemented in Torch as MultiCropWrapper
https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L183
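A rough JAX sketch of the same idea: batch consecutive crops that share a resolution, run the backbone once per group, and concatenate the features. The function name `multicrop_forward`, the NHWC layout, and the toy backbone are assumptions made for illustration; this is not a port of the Torch wrapper.

```python
import jax.numpy as jnp


def multicrop_forward(backbone, crops):
    """Apply `backbone` once per run of same-resolution crops and concatenate outputs."""
    groups = []
    for crop in crops:
        # Crops are assumed to be NHWC; compare the spatial dimensions only.
        if groups and groups[-1][0].shape[1:3] == crop.shape[1:3]:
            groups[-1].append(crop)
        else:
            groups.append([crop])
    features = [backbone(jnp.concatenate(group, axis=0)) for group in groups]
    return jnp.concatenate(features, axis=0)


def toy_backbone(x):
    # Stand-in for a real encoder: global average pool, NHWC -> NC.
    return x.mean(axis=(1, 2))


# Small resolutions keep the demo cheap; DINO uses 224x224 and 96x96.
global_crops = [jnp.ones((2, 32, 32, 3)) for _ in range(2)]
local_crops = [jnp.ones((2, 16, 16, 3)) for _ in range(3)]
out = multicrop_forward(toy_backbone, global_crops + local_crops)
```

Grouping by resolution keeps each backbone call at a fixed input shape, which also avoids recompilation under `jax.jit` for every crop.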
https://nvidia.github.io/apex/fp16_utils.html
This is supported in flax https://flax.readthedocs.io/en/latest/flax.optim.html.
Utilities for the convolution blocks needed for ResNet and ResNeXt. These should live in a separate file.
Implement loss for swapped prediction as in https://arxiv.org/pdf/2006.09882.pdf
Currently, we need the post process key to be in yaml. Add code to allow for this to be absent.
Set config for vit in dino_conf.yaml.
AllenNLP uses a registry system to allow seamless switching between classes.
Here is their implementation: https://github.com/allenai/allennlp/blob/main/allennlp/common/registrable.py
We do not need the fromParams part, as that is their JSON loader stuff.
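As a point of comparison, a registry without the JSON loading part can be quite small. The sketch below is a simplified stand-in, not AllenNLP's API: the `Registry` class, the `OPTIMIZERS` instance, and the `SGD` class are all hypothetical names.

```python
class Registry:
    """Maps string names to classes so configs can select implementations by name."""

    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            if name in self._classes:
                raise ValueError(f"'{name}' is already registered")
            self._classes[name] = cls
            return cls
        return decorator

    def get(self, name):
        if name not in self._classes:
            raise KeyError(f"'{name}' is not registered; available: {sorted(self._classes)}")
        return self._classes[name]


OPTIMIZERS = Registry()


@OPTIMIZERS.register("sgd")
class SGD:
    def __init__(self, lr=0.1):
        self.lr = lr


# A config entry such as {"name": "sgd", "params": {"lr": 0.01}} resolves to:
opt = OPTIMIZERS.get("sgd")(lr=0.01)
```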
Write a technical overview of DINO to include with the config in documentation.
We had some discussion earlier regarding whether the __init__ function of our classes should take in a single params object (or config, whatever we will call it) OR have its arguments listed out. I wanted to formally create an issue for this topic to discuss it further and explain my position fully.
To give an example, consider an encoder-decoder seq2seq model EncoderDecoder that has the architecture:
Embedder -> Encoder -> Decoder -> Linear Layer -> Dropout
Here are hypothetical __init__ functions for each idea.
If we take a params object (the HuggingFace approach):
def __init__(
    self,
    params: Params
):
    self.embedder = Embedder.fromParams(params.pop("embedder"))
    self.encoder = SeqEncoder.fromParams(params.pop("encoder"))
    self.decoder = SeqDecoder.fromParams(params.pop("decoder"))
    self.output_layer = Linear.fromParams(params.pop("output_layer"))
    self.dropout = jax.experimental.stax.Dropout(params.pop("dropout", 0))
    tied_source_embedder_key = params.pop("tied_source_embedder_key", None)
    if tied_source_embedder_key:
        self._tie_source_and_decoder_embedders(tied_source_embedder_key)
Here is how it would look if each argument of __init__ is specified (the AllenNLP approach):
def __init__(
    self,
    embedder: Embedder,
    encoder: SeqEncoder,
    decoder: SeqDecoder,
    output_layer: Linear,
    dropout: float = 0.0,
    tied_source_embedder_key: Optional[str] = None
):
    self.embedder = embedder
    self.encoder = encoder
    self.decoder = decoder
    self.output_layer = output_layer
    self.dropout = jax.experimental.stax.Dropout(dropout)
    if tied_source_embedder_key:
        self._tie_source_and_decoder_embedders(tied_source_embedder_key)
The difference here is that the fromParams calls are removed; the FromParams class handles them instead. It determines which constructor to call for each argument based on the type annotations and initializes those objects before passing them to the parent class. In this case, the parent class would be the EncoderDecoder model. Furthermore, it handles argument and type checking automatically. But this behavior can only be enabled by specifying the arguments in the __init__ signature. Without it, we would need to implement our own checks within the __init__ of each class.
I also believe that the former is much harder for end-users to understand without experience with the library.
I do want others' opinions on this and which direction we should go.
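To make the second option concrete, here is a small sketch of annotation-driven construction. It is a simplified stand-in and not AllenNLP's implementation: the `from_params` method, the `Encoder` class, and the reduced `EncoderDecoder` are hypothetical, and the checks only cover missing and unknown keys rather than full type validation.

```python
import inspect
from typing import get_type_hints


class FromParams:
    """Mixin that builds an object from a nested dict using __init__ annotations."""

    @classmethod
    def from_params(cls, params: dict):
        signature = inspect.signature(cls.__init__)
        hints = get_type_hints(cls.__init__)
        unknown = set(params) - set(signature.parameters)
        if unknown:
            raise TypeError(f"{cls.__name__} got unexpected keys: {sorted(unknown)}")
        kwargs = {}
        for name, parameter in signature.parameters.items():
            if name == "self":
                continue
            if name not in params:
                if parameter.default is inspect.Parameter.empty:
                    raise TypeError(f"{cls.__name__} is missing required key '{name}'")
                continue  # fall back to the default in the signature
            value = params[name]
            annotation = hints.get(name)
            # Recurse when the annotated type is itself constructible from params.
            if (isinstance(value, dict) and inspect.isclass(annotation)
                    and issubclass(annotation, FromParams)):
                value = annotation.from_params(value)
            kwargs[name] = value
        return cls(**kwargs)


class Encoder(FromParams):
    def __init__(self, hidden_dim: int = 128):
        self.hidden_dim = hidden_dim


class EncoderDecoder(FromParams):
    def __init__(self, encoder: Encoder, dropout: float = 0.0):
        self.encoder = encoder
        self.dropout = dropout


model = EncoderDecoder.from_params({"encoder": {"hidden_dim": 64}, "dropout": 0.1})
```

Because the constructor arguments are explicit, the same classes can still be built by hand, for example `EncoderDecoder(Encoder(64), dropout=0.1)`, without going through a params object.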
Something like https://github.com/facebookresearch/vissl/blob/master/vissl/models/base_ssl_model.py
This work is on branch dev-sslcore
Base vicreg experiment config
Implementation details for pretraining with VICReg on the 1000-classes ImageNet dataset without labels are as follows. Coefficients λ and μ are 25 and ν is 1 in Eq. (6), and ε is 0.0001 in Eq. (1). We give more details on how we choose the coefficients of the loss function in Appendix C.3. The encoder network f_θ is a standard ResNet-50 backbone He et al. (2016) with 2048 output units. The expander h_φ is composed of two fully-connected layers with batch normalization (BN) Ioffe & Szegedy (2015) and ReLU, and a third linear layer. The sizes of all 3 layers were set to 8192. As with Barlow Twins, performance improves when the size of the expander layers is larger than the dimension of the representation. The impact of the expander dimension on performance is studied in Appendix C.5. The training protocol follows those of BYOL and Barlow Twins: LARS optimizer You et al. (2017); Goyal et al. (2017) run for 1000 epochs with a weight decay of 10^-6 and a learning rate lr = batch_size/256 × base_lr, where batch_size is set to 2048 by default and base_lr is a base learning rate set to 0.2. The learning rate follows a cosine decay schedule Loshchilov & Hutter (2017), starting from 0 with 10 warmup epochs and with final value of 0.002.
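The learning-rate rule in that excerpt can be sketched as a plain function of the step. Two points are assumptions rather than statements from the paper: the warmup is taken to be linear, and `steps_per_epoch` is a made-up parameter that depends on the dataset and batch size. With the defaults, the peak rate is 2048/256 × 0.2 = 1.6.

```python
import jax.numpy as jnp


def vicreg_lr(step, steps_per_epoch, batch_size=2048, base_lr=0.2,
              warmup_epochs=10, total_epochs=1000, final_lr=0.002):
    """Linear warmup from 0 to the peak rate, then cosine decay to final_lr."""
    peak_lr = batch_size / 256 * base_lr
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    warmup = peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = final_lr + 0.5 * (peak_lr - final_lr) * (1 + jnp.cos(jnp.pi * progress))
    return jnp.where(step < warmup_steps, warmup, cosine)


lr_start = vicreg_lr(0, steps_per_epoch=100)
lr_peak = vicreg_lr(10 * 100, steps_per_epoch=100)
lr_end = vicreg_lr(1000 * 100, steps_per_epoch=100)
```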
Working implementation of https://arxiv.org/abs/2006.07733
Fails at pip install
Implement loss function from https://arxiv.org/pdf/2105.04906.pdf
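One possible shape for this loss in JAX, using the coefficients quoted in the VICReg excerpt (λ = μ = 25, ν = 1, ε = 0.0001). This is a sketch under assumptions: the invariance term is a mean-squared error over all entries, the variance uses the unbiased estimator, and the function name and argument names are hypothetical.

```python
import jax
import jax.numpy as jnp


def vicreg_loss(z_a, z_b, lam=25.0, mu=25.0, nu=1.0, eps=1e-4):
    """Weighted sum of invariance, variance, and covariance terms for two N x D batches."""
    n, d = z_a.shape

    # Invariance: mean-squared error between the two embeddings.
    invariance = jnp.mean((z_a - z_b) ** 2)

    def variance_term(z):
        # Hinge on the per-dimension standard deviation, pushing it towards 1.
        std = jnp.sqrt(jnp.var(z, axis=0, ddof=1) + eps)
        return jnp.mean(jax.nn.relu(1.0 - std))

    def covariance_term(z):
        # Penalize squared off-diagonal entries of the covariance matrix.
        z = z - z.mean(axis=0)
        cov = (z.T @ z) / (n - 1)
        off_diagonal = cov - jnp.diag(jnp.diag(cov))
        return jnp.sum(off_diagonal ** 2) / d

    variance = variance_term(z_a) + variance_term(z_b)
    covariance = covariance_term(z_a) + covariance_term(z_b)
    return lam * invariance + mu * variance + nu * covariance


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
z_a = jax.random.normal(key_a, (8, 16))
z_b = jax.random.normal(key_b, (8, 16))
loss = vicreg_loss(z_a, z_b)
```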
Task.meter takes logits, targets, weights and returns metrics
From https://arxiv.org/abs/2104.14294,
While our framework can be stabilized with multiple normalizations [10], it can also work with only a centering and sharpening of the momentum teacher outputs to avoid model collapse. As shown experimentally in Section 5.3, centering prevents one dimension to dominate but encourages collapse to the uniform distribution, while the sharpening has the opposite effect. Applying both operations balances their effects which is sufficient to avoid collapse in presence of a momentum teacher. Choosing this method to avoid collapse trades stability for less dependence over the batch: the centering operation only depends on first-order batch statistics and can be interpreted as adding a bias term c to the teacher: g_t(x) ← g_t(x) + c. The center c is updated with an exponential moving average, which allows the approach to work well across different batch sizes as shown in Section 5.5.
Implemented in Pytorch here https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L363
def update_center(self, teacher_output):
    """
    Update center used for teacher output.
    """
    batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
    dist.all_reduce(batch_center)
    batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
    # ema update
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
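A functional JAX counterpart might look like the sketch below. Since JAX has no mutable module state, the center is passed in and returned. The default `center_momentum=0.9` and the single-device mean are assumptions; under `jax.pmap` the cross-device average would replace the `all_reduce` call, as noted in the comment.

```python
import jax.numpy as jnp


def update_center(center, teacher_output, center_momentum=0.9):
    """Return the EMA-updated center given teacher outputs of shape (N, D)."""
    batch_center = jnp.mean(teacher_output, axis=0, keepdims=True)
    # Inside jax.pmap(..., axis_name="batch"), average across devices with:
    # batch_center = jax.lax.pmean(batch_center, axis_name="batch")
    return center * center_momentum + batch_center * (1 - center_momentum)


center = jnp.zeros((1, 4))
teacher_output = jnp.ones((8, 4))
new_center = update_center(center, teacher_output)
```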
Implement SSL losses.
https://github.com/ryanccarelli/ssljax/tree/dev-losses/ssljax/losses
From BYOL paper,
For the target network, the exponential moving average parameter τ starts from τ_base = 0.996 and is increased to one during training. Specifically, we set τ ≜ 1 − (1 − τ_base) · (cos(πk/K) + 1)/2 with k the current training step and K the maximum number of training steps.
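The schedule and the corresponding target-network update can be sketched as follows. The function names `byol_tau` and `ema_update` are hypothetical, and the parameters are assumed to be stored as a pytree (for example a nested dict of arrays).

```python
import math

import jax
import jax.numpy as jnp


def byol_tau(step, max_steps, tau_base=0.996):
    """EMA coefficient that rises from tau_base at step 0 to 1 at max_steps."""
    return 1.0 - (1.0 - tau_base) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0


def ema_update(target_params, online_params, tau):
    """Move target parameters towards online parameters: tau * target + (1 - tau) * online."""
    return jax.tree_util.tree_map(
        lambda t, o: tau * t + (1.0 - tau) * o, target_params, online_params
    )


target = {"w": jnp.zeros(2)}
online = {"w": jnp.ones(2)}
updated = ema_update(target, online, tau=0.9)
```

With `tau = 1` the target network stops moving, which is where the schedule ends up at the final training step.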
Increase test coverage by mocking configs
Our random crop augmentation should be made random
3-layer MLP (2048 hidden dim) with GELU, last layer without GELU
l2 norm
weight normalized fully connected layer
no batchnorm
https://github.com/ryanccarelli/ssljax/tree/dev-losses/ssljax/losses
Need to implement the losses
Add trainer class to manage model training and feature extraction.
How do we accumulate batch stats across multiple gradient steps? Is this something that is done or does the complexity outweigh the benefits for doing this?
See https://github.com/facebookresearch/vissl/tree/master/vissl/models/heads
Need to implement eval loop so we know if our models do things
Wrap and register optax optimizers.
From https://arxiv.org/abs/2006.09882,
We obtain two different views from an image by performing crops of random sizes and aspect ratios.
Specifically we use the RandomResizedCrop method from torchvision.transforms module
of PyTorch with the following scaling parameters: s=(0.14, 1). Note that we sample crops in
a narrower range of scale compared to the default RandomResizedCrop parameters. Then, we
resize both full resolution views to 224 × 224 pixels, unless specified otherwise (we use 160 × 160
resolutions in some of our experiments). Besides, we obtain V additional views by cropping small
parts in the image. To do so, we use the following RandomResizedCrop parameters: s=(0.05,
0.14). We resize the resulting crops to 96 × 96 resolution. Note that we always deal with resolutions
that are divisible by 32 to avoid roundings in the ResNet-50 pooling layers. Finally, we apply random
horizontal flips, color distortion and Gaussian blur to each resulting crop, exactly following the
SimCLR implementation [10]. An illustration of our multi-crop augmentation strategy can be
viewed in Fig. 5
Pytorch libraries:
API
Repository
Models
Examples
We need to support flexible evals, like linear, knn, or other task. Evals will be specified in their own configuration files.
Different evals require different model definition (ex. removing the head in linear evals)
Different evals require different data preprocessing, as in https://github.com/deepmind/deepmind-research/blob/2c7c401024c42c4fb1aa20a8b0471d2e6b480906/byol/utils/dataset.py#L56
Right now, the probability of applying an augmentation is a property of the Augmentation itself, i.e.
class Augmentation:
    """
    An augmentation applies a function to data.
    Args:
        prob (float): Probability that augmentation will be executed.
    """
    def __init__(self, prob=1.0):
        assert isinstance(prob, float), "prob must be of type float"
        self.prob = prob
This variable defines the probability of applying an augmentation or not. This allows for pipelines like:
RandomFlip(0.5) → GaussianBlur(0.5) → ....
However, this framework is not flexible enough to allow for distributions over multiple augmentations:
sample one ( RandomFlip(0.5), GaussianBlur(0.5) ) → ....
This issue proposes refactoring Augmentation to be a deterministic transform with its probability being the property of an AugmentationDistribution class
class AugmentationDistribution:
    """
    A categorical distribution of augmentations to apply to input
    Args:
        probs (List[float]): List of probabilities
        augs (List[Augmentation]): List of augmentations to be sampled
    """
    def __init__(self, probs, augs):
        self.probs = jax.numpy.array(probs)
        self.augs = augs
This will allow us to express the distribution over multiple augmentations mentioned above via:
AugmentationDistribution([0.5,0.5],[Flip,GaussianBlur])
A proof of concept for the above refactor is shown below
import jax
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
class Augmentation:
    def __init__(self, add):
        self.add = add
    def __call__(self, x):
        return self.add + x
class AugmentationDistribution:
    def __init__(self, probs, augs):
        self.probs = jax.numpy.array(probs)
        self.augs = augs
    def apply(self, rng, x):
        key, subkey = jax.random.split(rng)
        sampledIndex = jax.random.choice(subkey, len(self.augs), p=self.probs)
        x = jax.lax.switch(sampledIndex, self.augs, x)
        return x, key
identity = Augmentation(0.0)
addOne = Augmentation(1.0)
addHalf = Augmentation(0.5)
pipeline = [AugmentationDistribution([0.2, 0.8], [addOne, addHalf]),
            AugmentationDistribution([0.2, 0.8], [addOne, identity])]
@jax.jit
def applyAugPipeline(rng, x):
    for augDist in pipeline:
        x, rng = augDist.apply(rng, x)
    return x
print(applyAugPipeline(key, 0))
Need to add our copyright headers to every file AND adhere to the Apache 2.0 License for any of the AllenNLP code we use.
We need an ema scheduler for the momentum encoder.
We define two classes, Augmentation and Pipeline.
class Augmentation:
    def __init__(self, prob): ...
    def __call__(self, x): ...
class Pipeline(Augmentation):
    def __init__(self, augmentations): ...
    def __call__(self, x): ...
    # this just composes augmentations
    def sample(self): ...
Use cases:
See vissl/vissl/config/defaults.yaml for experiment details. This will have the following configuration sections
There isn't consensus in reference implementations about where to augment.
Lucidrains augments in (his equivalent to) the BaseSSL model forward pass
https://github.com/lucidrains/byol-pytorch/blob/master/byol_pytorch/byol_pytorch.py
DeepMind augments in the forward pass (in the loss function!)
Another natural thing to do is augment in the dataloader (augmentation happens right before passing data to Body)
We should benchmark different approaches
def update_center(self, teacher_output):
    """
    Update center used for teacher output.
    """
    batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
    dist.all_reduce(batch_center)
    batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
    # ema update
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
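For orientation, the no-batchnorm configuration of this head could be written in plain JAX roughly as below. It is a sketch, not a faithful port: the function names are hypothetical, the weight-norm gain is held fixed at 1 (the `norm_last_layer=True` case), and the initialization only approximates the Torch code (truncation here is at two standard deviations, and the last layer reuses the same initializer as an assumption).

```python
import jax
import jax.numpy as jnp


def init_dino_head(key, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256, nlayers=3):
    """Create parameters for the MLP plus the weight-normalized last layer."""
    dims = [in_dim] + [hidden_dim] * (max(nlayers, 1) - 1) + [bottleneck_dim]
    keys = jax.random.split(key, len(dims))
    mlp = [
        {"w": 0.02 * jax.random.truncated_normal(k, -2.0, 2.0, (d_in, d_out)),
         "b": jnp.zeros(d_out)}
        for k, d_in, d_out in zip(keys[:-1], dims[:-1], dims[1:])
    ]
    last_v = 0.02 * jax.random.truncated_normal(keys[-1], -2.0, 2.0, (bottleneck_dim, out_dim))
    return {"mlp": mlp, "last_v": last_v}


def dino_head_apply(params, x):
    """MLP with GELU (none after the last MLP layer), l2 norm, weight-normalized linear."""
    *hidden_layers, bottleneck = params["mlp"]
    for layer in hidden_layers:
        x = jax.nn.gelu(x @ layer["w"] + layer["b"], approximate=False)
    x = x @ bottleneck["w"] + bottleneck["b"]
    # l2-normalize the bottleneck features.
    x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-12)
    # Weight norm with gain 1: each output unit's weight vector has unit norm.
    v = params["last_v"]
    w = v / jnp.linalg.norm(v, axis=0, keepdims=True)
    return x @ w  # no bias in the last layer


head_params = init_dino_head(jax.random.PRNGKey(0), in_dim=32, out_dim=10,
                             hidden_dim=64, bottleneck_dim=16)
head_out = dino_head_apply(head_params, jnp.ones((4, 32)))
```

Because both the features and the last-layer weight vectors have unit norm, every output entry is a cosine similarity and lies in [-1, 1].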
Which approach should we use to support a distributed sampler?
From DINO (torch.utils.data)
transform = DataAugmentationDINO(
    args.global_crops_scale,
    args.local_crops_scale,
    args.local_crops_number,
)
dataset = datasets.ImageFolder(args.data_path, transform=transform)
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
data_loader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size_per_gpu,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
From BYOL (tfds) https://github.com/deepmind/deepmind-research/blob/2c7c401024c42c4fb1aa20a8b0471d2e6b480906/byol/utils/dataset.py#L56
class PreprocessMode(enum.Enum):
    """Preprocessing modes for the dataset."""
    PRETRAIN = 1  # Generates two augmented views (random crop + augmentations).
    LINEAR_TRAIN = 2  # Generates a single random crop.
    EVAL = 3  # Generates a single center crop.
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
    """Normalize the image using ImageNet statistics."""
    mean_rgb = (0.485, 0.456, 0.406)
    stddev_rgb = (0.229, 0.224, 0.225)
    normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3))
    normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3))
    return normed_images
def load(split: Split,
         *,
         preprocess_mode: PreprocessMode,
         batch_dims: Sequence[int],
         transpose: bool = False,
         allow_caching: bool = False) -> Generator[Batch, None, None]:
    """Loads the given split of the dataset."""
    start, end = _shard(split, jax.host_id(), jax.host_count())
    total_batch_size = np.prod(batch_dims)
    tfds_split = tfds.core.ReadInstruction(
        _to_tfds_split(split), from_=start, to=end, unit='abs')
    ds = tfds.load(
        'imagenet2012:5.*.*',
        split=tfds_split,
        decoders={'image': tfds.decode.SkipDecoding()})
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    if preprocess_mode is not PreprocessMode.EVAL:
        options.experimental_deterministic = False
        if jax.host_count() > 1 and allow_caching:
            # Only cache if we are reading a subset of the dataset.
            ds = ds.cache()
        ds = ds.repeat()
        ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
    else:
        if split.num_examples % total_batch_size != 0:
            raise ValueError(f'Test/valid must be divisible by {total_batch_size}')
    ds = ds.with_options(options)
    def preprocess_pretrain(example):
        view1 = _preprocess_image(example['image'], mode=preprocess_mode)
        view2 = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'view1': view1, 'view2': view2, 'labels': label}
    def preprocess_linear_train(example):
        image = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}
    def preprocess_eval(example):
        image = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}
    if preprocess_mode is PreprocessMode.PRETRAIN:
        ds = ds.map(
            preprocess_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    elif preprocess_mode is PreprocessMode.LINEAR_TRAIN:
        ds = ds.map(
            preprocess_linear_train,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
    else:
        ds = ds.map(
            preprocess_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    def transpose_fn(batch):
        # We use the double-transpose-trick to improve performance for TPUs. Note
        # that this (typically) requires a matching HWCN->NHWC transpose in your
        # model code. The compiler cannot make this optimization for us since our
        # data pipeline and model are compiled separately.
        batch = dict(**batch)
        if preprocess_mode is PreprocessMode.PRETRAIN:
            batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
            batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
        else:
            batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
        return batch
    for i, batch_size in enumerate(reversed(batch_dims)):
        ds = ds.batch(batch_size)
        if i == 0 and transpose:
            ds = ds.map(transpose_fn)  # NHWC -> HWCN
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    yield from tfds.as_numpy(ds)
For SimSiam we must support that modules in different branches share weights.
Currently we instantiate modules in Branch by iterating over each branch's config self.config.model.branches.items(). This change instantiates modules instead in SSLModel, then passes these built modules to branches.
Advantages:
Here is a scheme with parameter sharing:
modules:
  body1:
    name: ResNet
    params: {}
  head1:
    name: MLP
    params: {}
model:
  name: SSLModel
  branches:
    0:
      body: body1
      head: head1
    1:
      body: body1
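A small sketch of how such a config could be resolved: build every entry under `modules` once, then hand the same built objects to each branch that references them. The `build_branches` function, the stand-in `ResNet` and `MLP` classes, and the dict form of the config are assumptions for illustration. Sharing the Python object is only the first step; whether parameters are actually shared still depends on how the modules are called inside the model.

```python
class ResNet:
    """Stand-in for a real body module."""
    def __init__(self, **params):
        self.params = params


class MLP:
    """Stand-in for a real head module."""
    def __init__(self, **params):
        self.params = params


MODULE_CLASSES = {"ResNet": ResNet, "MLP": MLP}

# Dict equivalent of the YAML scheme above.
config = {
    "modules": {
        "body1": {"name": "ResNet", "params": {}},
        "head1": {"name": "MLP", "params": {}},
    },
    "model": {
        "name": "SSLModel",
        "branches": {
            0: {"body": "body1", "head": "head1"},
            1: {"body": "body1"},
        },
    },
}


def build_branches(config, module_classes):
    """Instantiate each named module once and wire the instances into branches."""
    built = {
        key: module_classes[spec["name"]](**spec["params"])
        for key, spec in config["modules"].items()
    }
    return {
        index: {role: built[ref] for role, ref in branch.items()}
        for index, branch in config["model"]["branches"].items()
    }


branches = build_branches(config, MODULE_CLASSES)
```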