# Multi-crop augmentation configured from the command-line arguments:
# global crop scale, local crop scale and number of local crops.
transform = DataAugmentationDINO(
    args.global_crops_scale,
    args.local_crops_scale,
    args.local_crops_number,
)

# Image-folder dataset rooted at ``args.data_path``; every sample goes through
# the augmentation pipeline above.
dataset = datasets.ImageFolder(args.data_path, transform=transform)

# Restricts each distributed process to its own (shuffled) subset of indices.
sampler = torch.utils.data.DistributedSampler(dataset=dataset, shuffle=True)

data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=args.batch_size_per_gpu,
    sampler=sampler,
    num_workers=args.num_workers,
    # Drop the trailing partial batch so every batch has batch_size_per_gpu
    # samples.
    drop_last=True,
    pin_memory=True,
)
class PreprocessMode(enum.Enum):
    """Selects how each dataset example is preprocessed.

    Members:
      PRETRAIN: two augmented views per example (random crop + augmentations).
      LINEAR_TRAIN: a single random crop per example.
      EVAL: a single center crop per example.
    """

    # Two augmented views (random crop + augmentations).
    PRETRAIN = 1
    # One random crop.
    LINEAR_TRAIN = 2
    # One center crop.
    EVAL = 3
def normalize_images(
    images: jnp.ndarray,
    *,
    mean_rgb: Sequence[float] = (0.485, 0.456, 0.406),
    stddev_rgb: Sequence[float] = (0.229, 0.224, 0.225),
) -> jnp.ndarray:
    """Normalize images per channel, by default with ImageNet statistics.

    Args:
      images: Array whose last axis is the channel axis, e.g. NHWC
        ``(batch, height, width, 3)`` or a single HWC ``(height, width, 3)``
        image. NOTE(review): values are presumably already scaled to [0, 1],
        since the default statistics are on that scale -- confirm with caller.
      mean_rgb: Per-channel mean subtracted from ``images``.
      stddev_rgb: Per-channel standard deviation ``images`` is divided by.

    Returns:
      ``(images - mean_rgb) / stddev_rgb``; when the last axis of ``images``
      has the same length as the statistics, the shape is unchanged.
    """
    # Rank-1 (C,) statistics broadcast against the trailing channel axis for
    # any input rank. Unlike a fixed (1, 1, 1, C) reshape, this does not add a
    # spurious leading axis to rank-3 (HWC) inputs.
    mean = jnp.array(mean_rgb)
    stddev = jnp.array(stddev_rgb)
    return (images - mean) / stddev
def load(split: Split,
         *,
         preprocess_mode: PreprocessMode,
         batch_dims: Sequence[int],
         transpose: bool = False,
         allow_caching: bool = False) -> Generator[Batch, None, None]:
    """Loads this host's shard of ImageNet ``split`` as batches of numpy arrays.

    Args:
      split: Dataset split to read. Only the ``[start, end)`` example range
        returned by ``_shard`` for the current JAX process is loaded.
      preprocess_mode: How examples are preprocessed. Any mode other than
        ``EVAL`` yields an infinite, shuffled, non-deterministic stream;
        ``EVAL`` yields a single unshuffled pass.
      batch_dims: Leading batch shape of the yielded arrays, outermost first.
        The stream is batched once per entry, innermost dimension first.
      transpose: If True, image arrays are transposed NHWC -> HWCN right after
        the innermost batching step (TPU double-transpose trick).
      allow_caching: If True and more than one JAX process is running, the
        per-host shard is cached with ``ds.cache()`` before repeating.

    Yields:
      For ``PRETRAIN``: dicts with keys ``'view1'``, ``'view2'``, ``'labels'``.
      Otherwise: dicts with keys ``'images'`` and ``'labels'``. Labels are
      int32.

    Raises:
      ValueError: In ``EVAL`` mode, if the split size or this host's shard
        size is not a multiple of ``prod(batch_dims)``. Because this is a
        generator, the error surfaces on the first ``next()`` call.
    """
    start, end = _shard(split, jax.process_index(), jax.process_count())
    # int(): np.prod([]) is the float 1.0, which tf.data rejects as a buffer
    # size; for non-empty integer batch_dims the value is unchanged.
    total_batch_size = int(np.prod(batch_dims))

    tfds_split = tfds.core.ReadInstruction(
        _to_tfds_split(split), from_=start, to=end, unit='abs')
    ds = tfds.load(
        'imagenet2012:5.*.*',
        split=tfds_split,
        # Keep images encoded; presumably decoded inside _preprocess_image.
        decoders={'image': tfds.decode.SkipDecoding()})

    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1

    if preprocess_mode is not PreprocessMode.EVAL:
        # Training order need not be reproducible; let tf.data trade
        # determinism for throughput.
        options.experimental_deterministic = False
        if jax.process_count() > 1 and allow_caching:
            # Only cache if we are reading a subset of the dataset.
            ds = ds.cache()
        ds = ds.repeat()
        ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
    else:
        # Batching below keeps remainders, so any leftover examples would form
        # a ragged final batch. The per-host shard [start, end) is what is
        # actually batched here, so it must divide evenly too, not just the
        # whole split.
        num_shard_examples = end - start
        if (split.num_examples % total_batch_size != 0
                or num_shard_examples % total_batch_size != 0):
            raise ValueError(
                f'Test/valid must be divisible by {total_batch_size} '
                f'(split has {split.num_examples} examples, this host\'s '
                f'shard has {num_shard_examples}).')

    ds = ds.with_options(options)

    def preprocess_pretrain(example):
        # Two independent calls to the (presumably randomized) preprocessing
        # produce two views of the same image.
        view1 = _preprocess_image(example['image'], mode=preprocess_mode)
        view2 = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'view1': view1, 'view2': view2, 'labels': label}

    def preprocess_single_view(example):
        # LINEAR_TRAIN and EVAL differ only through the `mode` argument.
        image = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}

    if preprocess_mode is PreprocessMode.PRETRAIN:
        preprocess_fn = preprocess_pretrain
        image_keys = ('view1', 'view2')
    else:
        preprocess_fn = preprocess_single_view
        image_keys = ('images',)

    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    def transpose_fn(batch):
        # We use the double-transpose-trick to improve performance for TPUs.
        # Note that this (typically) requires a matching HWCN->NHWC transpose
        # in your model code. The compiler cannot make this optimization for
        # us since our data pipeline and model are compiled separately.
        batch = dict(**batch)
        for key in image_keys:
            batch[key] = tf.transpose(batch[key], (1, 2, 3, 0))
        return batch

    for i, batch_size in enumerate(reversed(batch_dims)):
        ds = ds.batch(batch_size)
        if i == 0 and transpose:
            ds = ds.map(transpose_fn)  # NHWC -> HWCN

    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    yield from tfds.as_numpy(ds)