
jepa's Introduction

V-JEPA: Video Joint Embedding Predictive Architecture

Official PyTorch codebase for the video joint-embedding predictive architecture, V-JEPA, a method for self-supervised learning of visual representations from video.

Meta AI Research, FAIR

Adrien Bardes, Quentin Garrido, Jean Ponce, Xinlei Chen, Michael Rabbat, Yann LeCun, Mahmoud Assran*, Nicolas Ballas*

[Blog] [Paper] [Yannic Kilcher's Video]

V-JEPA models are trained by passively watching video pixels from the VideoMix2M dataset, and produce versatile visual representations that perform well on downstream video and image tasks, without adaptation of the model’s parameters; e.g., using a frozen backbone and only a light-weight task-specific attentive probe.

Method

V-JEPA pretraining is based solely on an unsupervised feature prediction objective, and does not utilize pretrained image encoders, text, negative examples, human annotations, or pixel-level reconstruction.
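As a rough orientation, the sketch below illustrates such a feature-prediction step in PyTorch-style pseudocode. It is schematic only: context_encoder, target_encoder, and predictor are placeholder modules (not the repo's actual API), and the index tensors simply pick out visible versus masked tokens.

import torch
import torch.nn.functional as F

# Schematic V-JEPA-style feature-prediction step (illustrative only; the three
# modules below are placeholders, not the classes defined in this codebase).
def feature_prediction_step(video_tokens, ctx_idx, tgt_idx,
                            context_encoder, target_encoder, predictor):
    # Encode only the visible (context) tokens of the video.
    z_ctx = context_encoder(video_tokens[:, ctx_idx])

    # Target features come from a frozen / EMA target encoder; no gradients flow here.
    with torch.no_grad():
        z_tgt = target_encoder(video_tokens)[:, tgt_idx]

    # Predict the latent features of the masked region from the visible context.
    z_pred = predictor(z_ctx, tgt_idx)

    # Regression loss in feature space -- no pixel-level reconstruction.
    return F.l1_loss(z_pred, z_tgt)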


Visualizations

As opposed to generative methods that have a pixel decoder, V-JEPA has a predictor that makes predictions in latent space. We train a conditional diffusion model to decode the V-JEPA feature-space predictions to interpretable pixels; the pretrained V-JEPA encoder and predictor networks are kept frozen in this process. The decoder is only fed the representations predicted for the missing regions of the video, and does not have access to the unmasked regions of the video.

The V-JEPA feature predictions are indeed grounded, and exhibit spatio-temporal consistency with the unmasked regions of the video.



MODEL ZOO

Pretrained models

| model | patch size | resolution | iterations | batch size | data | download |
| --- | --- | --- | --- | --- | --- | --- |
| ViT-L | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
| ViT-H | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
| ViT-H | 2x16x16 | 384x384 | 90K | 2400 | VideoMix2M | checkpoint / configs |
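As a rough sketch of how one of these checkpoints can be loaded onto an encoder (adapted from the evaluation snippets quoted in the issues further down this page; the file name, constructor arguments, and running from the repo root are assumptions):

import torch
import src.models.vision_transformer as vit

# Build a ViT-L/16 video encoder and load the pretraining checkpoint (sketch only).
encoder = vit.__dict__['vit_large'](
    img_size=224, patch_size=16, num_frames=16, tubelet_size=2)
ckpt = torch.load('vitl16.pth.tar', map_location='cpu')
state = ckpt['target_encoder']  # pretraining checkpoints store the target encoder under this key
state = {k.replace('module.', ''): v for k, v in state.items()}
print(encoder.load_state_dict(state, strict=False))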

K400 Attentive probes

| model | resolution | accuracy (16x8x3) | download |
| --- | --- | --- | --- |
| ViT-L/16 | 224x224 | 80.8 | attentive probe checkpoint / configs |
| ViT-H/16 | 224x224 | 82.0 | attentive probe checkpoint / configs |
| ViT-H/16 | 384x384 | 81.9 | attentive probe checkpoint / configs |

SSv2 Attentive probes

| model | resolution | accuracy (16x2x3) | download |
| --- | --- | --- | --- |
| ViT-L/16 | 224x224 | 69.5 | attentive probe checkpoint / configs |
| ViT-H/16 | 224x224 | 71.4 | attentive probe checkpoint / configs |
| ViT-H/16 | 384x384 | 72.2 | attentive probe checkpoint / configs |

ImageNet1K Attentive probes

| model | resolution | accuracy | download |
| --- | --- | --- | --- |
| ViT-L/16 | 224x224 | 74.8 | attentive probe checkpoint / configs |
| ViT-H/16 | 224x224 | 75.9 | attentive probe checkpoint / configs |
| ViT-H/16 | 384x384 | 77.4 | attentive probe checkpoint / configs |

Places205 Attentive probes

| model | resolution | accuracy | download |
| --- | --- | --- | --- |
| ViT-L/16 | 224x224 | 60.3 | attentive probe checkpoint / configs |
| ViT-H/16 | 224x224 | 61.7 | attentive probe checkpoint / configs |
| ViT-H/16 | 384x384 | 62.8 | attentive probe checkpoint / configs |

iNat21 Attentive probes

| model | resolution | accuracy | download |
| --- | --- | --- | --- |
| ViT-L/16 | 224x224 | 67.8 | attentive probe checkpoint / configs |
| ViT-H/16 | 224x224 | 67.9 | attentive probe checkpoint / configs |
| ViT-H/16 | 384x384 | 72.6 | attentive probe checkpoint / configs |
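To attach one of the downloaded attentive probes to a frozen encoder, the loading pattern looks roughly like the sketch below (mirroring the issue snippet near the end of this page; the probe file name and the num_classes value are assumptions):

import torch
from src.models.attentive_pooler import AttentiveClassifier

classifier = AttentiveClassifier(
    embed_dim=encoder.embed_dim,   # `encoder` is a frozen V-JEPA backbone loaded as above
    num_heads=encoder.num_heads,
    depth=1,
    num_classes=1000,              # e.g. 1000 for the ImageNet1K probe
)
probe_ckpt = torch.load('in1k-probe.pth.tar', map_location='cpu')
probe_state = {k.replace('module.', ''): v for k, v in probe_ckpt['classifier'].items()}
classifier.load_state_dict(probe_state)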

Code Structure

Config files: All experiment parameters are specified in config files (as opposed to command-line arguments). See the configs/ directory for example config files. Note, before launching an experiment, you must update the paths in the config file to point to your own directories, indicating where to save the logs and checkpoints and where to find the training data.

.
├── app                       # the only place where training loops are allowed
│   ├── vjepa                 #   Video JEPA pre-training
│   ├── main_distributed.py   #   entrypoint for launching app on slurm cluster
│   └── main.py               #   entrypoint for launching app locally on your machine for debugging
├── evals                     # the only place where evaluation of 'apps' are allowed
│   ├── image_classification  #   training an attentive probe for image classification with frozen backbone
│   ├── video_classification  #   training an attentive probe for video classification with frozen backbone
│   ├── main_distributed.py   #   entrypoint for launching distributed evaluations on slurm cluster
│   └── main.py               #   entrypoint for launching evaluations locally on your machine for debugging
├── src                       # the package
│   ├── datasets              #   datasets, data loaders, ...
│   ├── models                #   model definitions
│   ├── masks                 #   mask collators, masking utilities, ...
│   └── utils                 #   shared utilities
└── configs                   # the only place where config files are allowed (specify experiment params for app/eval runs)
    ├── evals                 #   configs for launching vjepa frozen evaluations
    └── pretrain              #   configs for launching vjepa pretraining

Data preparation

Video Datasets

V-JEPA pretraining and evaluations work with many standard video formats. To make a video dataset compatible with the V-JEPA codebase, you simply need to create a .csv file with the following format and then specify the path to this CSV file in your config.

/absolute_file_path.[mp4, webvid, etc.] $integer_class_label
/absolute_file_path.[mp4, webvid, etc.] $integer_class_label
/absolute_file_path.[mp4, webvid, etc.] $integer_class_label
...

Since V-JEPA is entirely unsupervised, the pretraining code will disregard the $integer_class_label in the CSV file. Thus, feel free to put a random value in this column. However, if you wish to run a supervised video classification evaluation on your video dataset, you must replace $integer_class_label with the ground truth label for each video.
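For example, a minimal helper for generating such a CSV might look like this (a sketch only; the directory path is a placeholder and the label column is filled with a dummy 0, which is fine for pretraining):

import csv
from pathlib import Path

video_dir = Path('/path/to/my_videos')  # placeholder: your directory of videos
with open('my_video_dataset.csv', 'w', newline='') as f:
    writer = csv.writer(f, delimiter=' ')
    for video_path in sorted(video_dir.rglob('*.mp4')):
        writer.writerow([str(video_path.resolve()), 0])  # dummy label; use real labels for eval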

Image Datasets

We use the standard PyTorch ImageFolder class in our image classification evals. Thus, to set up an image dataset for the image classification evaluation, first create a directory to store your image datasets $your_directory_containing_image_datasets. Next, download your image datasets into this directory in a format compatible with PyTorch ImageFolder.

For example, suppose we have a directory called my_image_datasets. We would then download our image datasets into this directory so that we end up with the following file tree

.
└── /my_image_datasets/                # where we store image datasets
    ├── places205/121517/pytorch/      #   Places205
    │   └── [...]
    ├── iNaturalist-2021/110421/       #   iNaturalist21
    │   └── [...]
    ├── [...]                          #   Other Image Datasets
    │   └── [...]
    └── imagenet_full_size/061417/     #   ImageNet1k
        └── train
        │   ├── $class_1
        │   │    ├── xxx.[png, jpeg, etc.]
        │   │    ├── [...]
        │   │    └── xxz.[png, jpeg, etc.]
        │   ├── [...]
        │   └── $class_n
        │       ├── abc.[png, jpeg, etc.]
        │       ├── [...]
        │       └── abz.[png, jpeg, etc.]
        └── val
            ├── $class_1
            │    ├── xxx.[png, jpeg, etc.]
            │    ├── [...]
            │    └── xxz.[png, jpeg, etc.]
            ├── [...]
            └── $class_n
                ├── abc.[png, jpeg, etc.]
                ├── [...]
                └── abz.[png, jpeg, etc.]
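A directory laid out this way can be consumed directly by torchvision's ImageFolder; the snippet below only illustrates the expected layout (the eval scripts construct their own transforms and data loaders):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

root = '/my_image_datasets/imagenet_full_size/061417/train'  # from the tree above
dataset = datasets.ImageFolder(root, transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]))
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)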

Launching V-JEPA pretraining

Local training

If you wish to debug your code or setup before launching a distributed training run, we provide the functionality to do so by running the pretraining script locally on a multi-GPU (or single-GPU) machine. Note, however, that reproducing our results requires launching distributed training.

The single-machine implementation starts from app/main.py, which parses the experiment config file and runs the pretraining locally on a multi-GPU (or single-GPU) machine. For example, to run V-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config configs/pretrain/vitl16.yaml, type the command:

python -m app.main \
  --fname configs/pretrain/vitl16.yaml \
  --devices cuda:0 cuda:1 cuda:2

Distributed training

To launch a distributed training run, the implementation starts from app/main_distributed.py, which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source submitit tool and provide examples for a SLURM cluster.

For example, to launch a distributed pre-training experiment using the config configs/pretrain/vitl16.yaml, type the command:

python -m app.main_distributed \
  --fname configs/pretrain/vitl16.yaml \
  --folder $path_to_save_stderr_and_stdout \
  --partition $slurm_partition

Launching Evaluations

Local training

If you wish to debug your eval code or setup before launching a distributed training run, we provide the functionality to do so by running the evaluation script locally on a multi-GPU (or single-GPU) machine. Note, however, that reproducing the full eval would require launching distributed training. The single-machine implementation starts from evals/main.py, which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine.

For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config configs/evals/vitl16_in1k.yaml, type the command:

python -m evals.main \
  --fname configs/evals/vitl16_in1k.yaml \
  --devices cuda:0 cuda:1 cuda:2

Distributed training

To launch a distributed evaluation run, the implementation starts from evals/main_distributed.py, which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source submitit tool and provide examples for a SLURM cluster.

For example, to launch a distributed ImageNet image classification experiment using the config configs/evals/vitl16_in1k.yaml, type the command:

python -m evals.main_distributed \
  --fname configs/evals/vitl16_in1k.yaml \
  --folder $path_to_save_stderr_and_stdout \
  --partition $slurm_partition

Similarly, to launch a distributed K400 video classification experiment using the config configs/evals/vitl16_k400.yaml, type the command:

python -m evals.main_distributed \
  --fname configs/evals/vitl16_k400.yaml \
  --folder $path_to_save_stderr_and_stdout \
  --partition $slurm_partition

Setup

Run:

conda create -n jepa python=3.9 pip
conda activate jepa
python setup.py install

License

See the LICENSE file for details about the license under which this code is made available.

Citation

If you find this repository useful in your research, please consider giving a star ⭐ and a citation

@article{bardes2024revisiting,
  title={Revisiting Feature Prediction for Learning Visual Representations from Video},
  author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Chen, Xinlei and Rabbat, Michael and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas},
  journal={arXiv:2404.08471},
  year={2024}
}


jepa's Issues

ViT-tiny encoder

Hi, I can see there are various encoder class configs including a ViT-mini. Will you be releasing one in the future?

data loading during eval

Hi,
You don't seem to utilize the 'shared_transform' feature during evaluation; instead, you opt to perform this transform per clip.
I wondered why, since during evaluation all the transforms should be the same. So why apply it per clip instead of to the entire video?
Best,
Orr

Evaluation / Testing data

I want to evaluate on basic ImageNet data. Although I have downloaded the config files and pretrained models from the GitHub site provided, I get the following error. I tried to use ImageNet-1K data, but I could not find the imagenet_full_size/061417/ data. Can you please provide more information about how to test/evaluate using the well-known ImageNet dataset, or provide any support file/info for imagenet_full_size/061417/?

File "/home/ayhan/jepa/evals/image_classification_frozen/eval.py", line 363, in load_pretrained
pretrained_dict = checkpoint['encoder']
KeyError: 'encoder'

Thank you.

Augmentations in the frozen evaluation

Shouldn't the frozen evaluation not use augmentations?

After looking at the code and reading the paper, I see that you apply random augmentations when computing the embeddings. The weights of the encoder are frozen. However, the evaluation is not frozen; a video will get different embeddings in different epochs. This is a bit misleading because these results wouldn't apply if I pre-extract the embeddings (if they are really frozen).

I think it'd be nice to see the performance of your models in such a setting. It's just a suggestion of something I believe others and I would find useful, but I understand if you can't do it for any reason.

(I assume the baselines do the same with the augmentations -- still my concern applies)

KeyError when running evals.main

Thanks for your brilliant work! Having downloaded the K400 pretrained checkpoint file (k400-probe.pth.tar) and modified the config yaml file for the corresponding dataset (specifying the data path), I ran evals.main and got a KeyError: jepa/evals/video_classification_frozen/eval.py", line 424, in load_pretrained
pretrained_dict = checkpoint[checkpoint_key]
KeyError: 'target_encoder'
Then I printed the keys in the checkpoint, and got the following keys:
keys in ckpt: ['classifier', 'opt', 'scaler', 'epoch', 'batch_size', 'world_size', 'lr']
It seems that there's no key named "target_encoder". What's the reason behind this error?
Looking forward to your reply

Any plan of releasing decoder checkpoint?

Hi V-Jepa team,

Thank you for the great work! I am wondering if you have any plans to release the checkpoint for the decoder (reconstructing images from the generated latents). I am considering utilizing the V-JEPA model for specific experiments, but performing these experiments is challenging without the decoder.

PCA feature map visualization of pre-trained weights looks very random, compared to without pre-trained weights loaded

Hi,

Thank you for this amazing project.

I have been exploring the feature maps produced by the pre-trained V-JEPA, using PCA component visualization.
[screenshot: PCA feature map visualization with the pretrained weights loaded]

However, the feature maps look very random, so I tried doing the same thing without the pretrained weights.
[screenshot: PCA feature map visualization without pretrained weights]

Were the feature maps from V-JEPA pre-training supposed to look like this, or what did I miss in loading the pretrained weights?

Here is the code I used to do the feature visualization.

# %%
from evals.video_classification_frozen.eval import make_dataloader
import matplotlib.pyplot as plt
import torch
import yaml
import numpy as np
import torch.nn.functional as F
from app.vjepa.utils import (
    init_video_model,
)

# %%
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    # features: (N, C)
    # m: a hyperparam controlling how many std dev outside for outliers
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()
    d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = colors[fg_mask][s[:, 0] < m, 0]
        gins = colors[fg_mask][s[:, 1] < m, 1]
        bins = colors[fg_mask][s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except:
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])

    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)


def get_pca_map_whole_volume(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
    remove_first_component=False,
):
    """
    feature_map: (num_frames, h, w, C) is the feature map of a single image.
    """
    # print(feature_map.shape)
    if feature_map.shape[0] != 1:
        # make it (1, num_frames, h, w, C)
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1]),
            remove_first_component=remove_first_component,
        )
    else:
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)
    resized_pca_colors = []
    for i in range(pca_color.shape[1]):
        resized_pca_color = F.interpolate(
            pca_color[:, i, :, :, :].permute(0, 3, 1, 2),
            size=img_size,
            mode=interpolation,
        ).permute(0, 2, 3, 1)
        resized_pca_colors.append(resized_pca_color.cpu().numpy().squeeze(0))
    pca_color = np.stack(resized_pca_colors, axis=0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color


# %%
with open('configs/pretrain/vitl16.yaml', 'r') as y_file:
    args = yaml.load(y_file, Loader=yaml.FullLoader)

# -- set device
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

# -- META
cfgs_meta = args.get('meta')
use_sdpa = cfgs_meta.get('use_sdpa', False)

# -- MODEL
cfgs_model = args.get('model')
model_name = cfgs_model.get('model_name')
pred_depth = cfgs_model.get('pred_depth')
pred_embed_dim = cfgs_model.get('pred_embed_dim')
uniform_power = cfgs_model.get('uniform_power', True)
use_mask_tokens = cfgs_model.get('use_mask_tokens', True)
zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True)

# -- MASK
cfgs_mask = args.get('mask')

# -- DATA
cfgs_data = args.get('data')
dataset_type = cfgs_data.get('dataset_type', 'videodataset')
mask_type = cfgs_data.get('mask_type', 'multiblock3d')
dataset_paths = cfgs_data.get('datasets', [])
datasets_weights = cfgs_data.get('datasets_weights', None)
if datasets_weights is not None:
    assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset'
batch_size = cfgs_data.get('batch_size')
batch_size = 1
num_clips = cfgs_data.get('num_clips')
num_frames = cfgs_data.get('num_frames')
tubelet_size = cfgs_data.get('tubelet_size')
sampling_rate = cfgs_data.get('sampling_rate')
duration = cfgs_data.get('clip_duration', None)
crop_size = cfgs_data.get('crop_size', 224)
patch_size = cfgs_data.get('patch_size')
pin_mem = cfgs_data.get('pin_mem', False)
num_workers = cfgs_data.get('num_workers', 1)
filter_short_videos = cfgs_data.get('filter_short_videos', False)
decode_one_clip = cfgs_data.get('decode_one_clip', True)
log_resource_util_data = cfgs_data.get('log_resource_utilization', False)

eval_num_segments = 1
attend_across_segments = False
world_size = 1
rank = 0

# %%
train_data_path = ['lol.csv']
# train_data_path = ['/storage_bizon/naravich/Unlabeled_OCT_videos/Unlabel_OCT_Video.csv']
data_loader = make_dataloader(
        dataset_type=dataset_type,
        root_path=train_data_path,
        resolution=crop_size,
        frames_per_clip=num_frames,
        frame_step=sampling_rate,
        eval_duration=duration,
        num_segments=eval_num_segments if attend_across_segments else 1,
        num_views_per_segment=1,
        allow_segment_overlap=True,
        batch_size=batch_size,
        world_size=world_size,
        rank=rank,
        training=False)

for data in data_loader:
    clips, masks_enc, masks_pred = data
    break

# %%
clips[0][0].shape
min_val = clips[0][0][0].permute(1, 2, 3, 0)[0].numpy().min()
max_val = clips[0][0][0].permute(1, 2, 3, 0)[0].numpy().max()
img = (clips[0][0][0].permute(1, 2, 3, 0)[0].numpy() - min_val) / (max_val - min_val)
print(img.min(), img.max())
plt.imshow(img)

# %%
encoder, predictor = init_video_model(
    uniform_power=uniform_power,
    use_mask_tokens=use_mask_tokens,
    num_mask_tokens=len(cfgs_mask),
    zero_init_mask_tokens=zero_init_mask_tokens,
    device=device,
    patch_size=patch_size,
    num_frames=num_frames,
    tubelet_size=tubelet_size,
    model_name=model_name,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_embed_dim=pred_embed_dim,
    use_sdpa=use_sdpa,
)

# %%
checkpoint = torch.load('vitl16.pth.tar', map_location='cpu')
# checkpoint = torch.load('vith16.pth.tar', map_location='cpu')
print(checkpoint.keys())
new_encoder_state_dict = {}
pretrained_dict = checkpoint['target_encoder']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
# pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
encoder.load_state_dict(pretrained_dict)

# %%
x = encoder(clips[0][0].to(device))

output_of_vjepa = x
print('output_of_vjepa:', x.shape)
print('input shape:', clips[0][0].shape)
B, N, D = x.shape
B, C, FRAMES, H, W = clips[0][0].shape
# Patch = (tubelet_size, patch_size, patch_size)
N_FRAMES = FRAMES // tubelet_size
N_H = H // patch_size
N_W = W // patch_size

print(f'Thus, N feature ({output_of_vjepa.shape[1]}) is calculated from', H * W * FRAMES / tubelet_size / patch_size / patch_size)

# %%
image_size = (crop_size, crop_size)
volumne_pca_map =  get_pca_map_whole_volume(x.detach().reshape(batch_size, N_FRAMES, N_H, N_W, D), image_size, interpolation="bilinear", remove_first_component=False)
print(volumne_pca_map.shape)


# %%
fig, axes = plt.subplots(2, 8, figsize=(40, 20))
for i in range(8):
    axes[0, i].imshow(volumne_pca_map[i])

for clip_index in range(8):
    image = clips[0][0][0].permute(1, 2, 3, 0)[clip_index].numpy()
    image = (image - image.min()) / (image.max() - image.min())
    axes[1, clip_index].imshow(image)

# %%

The lol.csv which I downloaded from https://www.kaggle.com/datasets/ipythonx/ssv2test?resource=download

/home/naravich/projects/jepa/100972.webm 0

Crashes after first Epoch

I'm trying to get jepa to work on Colab, but for some reason it ends/crashes after completing the first epoch. The output folder is basically empty (one empty csv file).

The pretrained model used is vitl16.pth.tar. (https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar)

The dataset used is a bunch of mp4 videos (no class_labels / set to 0)

Could you give me some pointers on how to possibly debug this?

Environment:
Colab Pro, tried it with the A100 and V100.

Start of the training with:

!python -m evals.main --fname (my modified vith16_k400_16x8x3.yaml with a small dataset) --devices cuda:0

Output:

INFO:root:called-params /content/jepa/xxx-mini.yaml
INFO:root:loaded params...
{   'data': {   'dataset_train': '/content/jepa/xxx-train-mini.csv',
                'dataset_type': 'VideoDataset',
                'dataset_val': '/content/jepa/xxx-val-mini.csv',
                'frame_step': 4,
                'frames_per_clip': 16,
                'num_classes': 100,
                'num_segments': 8,
                'num_views_per_segment': 3},
    'eval_name': 'video_classification_frozen',
    'nodes': 1,
    'optimization': {   'attend_across_segments': True,
                        'batch_size': 4,
                        'final_lr': 0.0,
                        'lr': 0.001,
                        'num_epochs': 20,
                        'resolution': 224,
                        'start_lr': 0.001,
                        'use_bfloat16': True,
                        'warmup': 0.0,
                        'weight_decay': 0.01},
    'pretrain': {   'checkpoint': 'vitl16.pth.tar',
                    'checkpoint_key': 'target_encoder',
                    'clip_duration': None,
                    'folder': '/content/jepa/',
                    'frames_per_clip': 16,
                    'model_name': 'vit_large',
                    'patch_size': 16,
                    'tight_silu': False,
                    'tubelet_size': 2,
                    'uniform_power': True,
                    'use_sdpa': True,
                    'use_silu': False,
                    'write_tag': 'jepa'},
    'resume_checkpoint': False,
    'tag': 'xxx2',
    'tasks_per_node': 8}
INFO:root:Running... (rank: 0/1)
INFO:root:Running evaluation: video_classification_frozen
INFO:root:Initialized (rank/world-size) 0/1
INFO:root:Loading pretrained model from /content/jepa/vitl16.pth.tar
VisionTransformer(
  (patch_embed): PatchEmbed3D(
    (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
  )
  (blocks): ModuleList(
    (0-23): 24 x Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
INFO:root:loaded pretrained model with msg: <All keys matched successfully>
INFO:root:loaded pretrained encoder from epoch: 300
 path: /content/jepa/vitl16.pth.tar
INFO:root:VideoDataset dataset created
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 12 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
INFO:root:VideoDataset unsupervised data loader created
Making EvalVideoTransform, multi-view
INFO:root:VideoDataset dataset created
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 12 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
INFO:root:VideoDataset unsupervised data loader created
INFO:root:Dataloader created... iterations per epoch: 1076
INFO:root:Using AdamW
INFO:root:Epoch 1
INFO:root:[    0] 0.000% (loss: 4.641) [mem: 3.13e+03]
INFO:root:[   20] 90.476% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[   40] 95.122% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[   60] 96.721% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[   80] 97.531% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  100] 98.020% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  120] 98.347% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  140] 98.582% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  160] 98.758% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  180] 98.895% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  200] 99.005% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  220] 99.095% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  240] 99.170% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  260] 99.234% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  280] 99.288% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  300] 99.336% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  320] 99.377% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  340] 99.413% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  360] 99.446% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  380] 99.475% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  400] 99.501% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  420] 99.525% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  440] 99.546% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  460] 99.566% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  480] 99.584% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  500] 99.601% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  520] 99.616% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  540] 99.630% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  560] 99.643% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  580] 99.656% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  600] 99.667% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  620] 99.678% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  640] 99.688% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  660] 99.697% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  680] 99.706% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  700] 99.715% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  720] 99.723% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  740] 99.730% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  760] 99.737% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  780] 99.744% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  800] 99.750% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  820] 99.756% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  840] 99.762% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  860] 99.768% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  880] 99.773% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  900] 99.778% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  920] 99.783% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  940] 99.787% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  960] 99.792% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[  980] 99.796% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1000] 99.800% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1020] 99.804% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1040] 99.808% (loss: 0.000) [mem: 3.24e+03]
INFO:root:[ 1060] 99.811% (loss: 0.000) [mem: 3.24e+03]
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 88 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

numpy import error

I cannot run the V-JEPA attentive probe checkpoint for my custom task. It seems to be some issue with the environment.
[screenshot of the numpy import error, dated 2024-03-11]

This error is a numpy import error. On further investigation I found that the environment I am using ("jepa") does not have its own numpy, so Python is looking for numpy in another conda env named "mmaction", which does have numpy installed but is failing to import the above module. Can someone please explain the issue and possibly any fix for the same?
PS: I cloned the repo and then installed the requirements with 'pip install -r requirements.txt'

ValueError: Default process group has not been initialized, please make sure to call init_process_group

First of all, thanks for providing this code 😄

tl;dr

I am getting ValueError when trying to run eval on iNat21 dataset with python -m evals.main --fname configs/evals/vitl16_inat.yaml --devices cuda:0 and running out of ideas how to fix it.

Config values

  • I try to run eval, using single GPU, local machine
  • dataset I use is iNaturalist-2021
  • configs\evals\vith16_inat.yaml looks like this:
nodes: 8
tasks_per_node: 8
tag: inat-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
  root_path: D:\__repos\jepa\data
  image_folder: inat
  num_classes: 10000
  resolution: 224
  dataset_name: iNat21
optimization:
  num_epochs: 20
  batch_size: 16
  weight_decay: 0.001
  lr: 0.001
  start_lr: 0.001
  final_lr: 0.0
  warmup: 0.
  use_bfloat16: true
pretrain:
  model_name: vit_large
  checkpoint_key: target_encoder
  clip_duration: null
  frames_per_clip: 16
  tubelet_size: 2
  uniform_power: true
  use_sdpa: true
  use_silu: false
  tight_silu: false
  patch_size: 16
  folder: D:\__repos\jepa\models
  checkpoint: vitl16.pth.tar  # name of pretrained model file inside folder
  write_tag: jepa
  • packages' versions:
Package            Version
------------------ ------------
certifi            2024.2.2
charset-normalizer 3.3.2
colorama           0.4.6
filelock           3.9.0
fsspec             2024.3.1
huggingface-hub    0.22.2
idna               3.7
Jinja2             3.1.2
MarkupSafe         2.1.3
mpmath             1.3.0
networkx           3.2.1
numpy              1.26.3
packaging          24.0
pillow             10.2.0
pip                22.0.4
PyYAML             6.0.1
requests           2.31.0
safetensors        0.4.3
setuptools         58.1.0
sympy              1.12
timm               0.9.16
torch              2.2.2+cu118
torchvision        0.17.2+cu118
tqdm               4.66.2
typing_extensions  4.8.0
urllib3            2.2.1

I have tried

  • as far as I understand, mentioned problem is caused by torch.distributed being available, but not initialized, but I haven't been able to pinpoint where this happens
  • I've run this little 'checklist' from SO and got
>>> import torch
>>> torch.cuda.is_available()
True
>>> torch.cuda.device_count()
1
>>> torch.cuda.current_device()
0
>>> torch.cuda.get_device_name(0)
'NVIDIA GeForce RTX 3060'
  • this pytorch forum post suggests that wrong usage of DistributedDataParallel is the root cause, but I haven't found such usage in the repo
  • this GitHub issue suggested SyncBatchNorm behaving in an unexpected way when running on a single GPU, but this has already been fixed in this PR
  • this problem also seems similar to this and this issue
  • I've also tried commenting out these lines:
    world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
    logger.info(f'Running... (rank: {rank}/{world_size})')

in evals.main, to avoid using the init_distributed function

Full stacktrace

(venv) PS D:\__repos\jepa> python -m evals.main --fname configs/evals/vitl16_inat.yaml
INFO:root:called-params configs/evals/vitl16_inat.yaml
INFO:root:loaded params...
{   'data': {   'dataset_name': 'iNat21',
                'image_folder': 'inat',
                'num_classes': 10000,
                'resolution': 224,
                'root_path': 'D:\\__repos\\jepa\\data'},
    'eval_name': 'image_classification_frozen',
    'nodes': 8,
    'optimization': {   'batch_size': 16,
                        'final_lr': 0.0,
                        'lr': 0.001,
                        'num_epochs': 20,
                        'start_lr': 0.001,
                        'use_bfloat16': True,
                        'warmup': 0.0,
                        'weight_decay': 0.001},
    'pretrain': {   'checkpoint': 'vitl16.pth.tar',
                    'checkpoint_key': 'target_encoder',
                    'clip_duration': None,
                    'folder': 'D:\\__repos\\jepa\\models',
                    'frames_per_clip': 16,
                    'model_name': 'vit_large',
                    'patch_size': 16,
                    'tight_silu': False,
                    'tubelet_size': 2,
                    'uniform_power': True,
                    'use_sdpa': True,
                    'use_silu': False,
                    'write_tag': 'jepa'},
    'resume_checkpoint': False,
    'tag': 'inat-16f',
    'tasks_per_node': 8}
D:\__repos\jepa\venv\lib\site-packages\torch\distributed\distributed_c10d.py:608: UserWarning: Attempted to get default timeout for nccl backend, but NCCL support is not compiled
  warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled")
INFO:root:Rank: 0. Distributed training not available Distributed package doesn't have NCCL built in
INFO:root:Running... (rank: 0/1)
INFO:root:Running evaluation: image_classification_frozen
INFO:root:SLURM vars not set (distributed training not available)
INFO:root:Initialized (rank/world-size) 0/1
INFO:root:Loading pretrained model from D:\__repos\jepa\models\vitl16.pth.tar
VisionTransformer(
  (patch_embed): PatchEmbed3D(
    (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
  )
  (blocks): ModuleList(
    (0-23): 24 x Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
)
INFO:root:loaded pretrained model with msg: <All keys matched successfully>
INFO:root:loaded pretrained encoder from epoch: 300
 path: D:\__repos\jepa\models\vitl16.pth.tar
INFO:root:implementing auto-agument strategy
INFO:root:data-path D:\__repos\jepa\data\inat\train/
INFO:root:Initialized ImageFolder
INFO:root:ImageFolder dataset created
INFO:root:ImageFolder unsupervised data loader created
INFO:root:data-path D:\__repos\jepa\data\inat\val/
INFO:root:Initialized ImageFolder
INFO:root:ImageFolder dataset created
INFO:root:ImageFolder unsupervised data loader created
INFO:root:Dataloader created... iterations per epoch: 31250
INFO:root:Using AdamW
Process Process-1:
Traceback (most recent call last):
  File "C:\Users\Maciek\AppData\Local\Programs\Python\Python310\lib\multiprocessing\process.py", line 315, in _bootstrap
    self.run()
  File "C:\Users\Maciek\AppData\Local\Programs\Python\Python310\lib\multiprocessing\process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "D:\__repos\jepa\evals\main.py", line 57, in process_main
    eval_main(params['eval_name'], args_eval=params)
  File "D:\__repos\jepa\evals\scaffold.py", line 22, in main
    return importlib.import_module(f'evals.{eval_name}.eval').main(
  File "D:\__repos\jepa\evals\image_classification_frozen\eval.py", line 201, in main
    classifier = DistributedDataParallel(classifier, static_graph=True)
  File "D:\__repos\jepa\venv\lib\site-packages\torch\nn\parallel\distributed.py", line 731, in __init__
    self.process_group = _get_default_group()
  File "D:\__repos\jepa\venv\lib\site-packages\torch\distributed\distributed_c10d.py", line 977, in _get_default_group
    raise ValueError(
ValueError: Default process group has not been initialized, please make sure to call init_process_group.
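One possible workaround sketch (standard PyTorch only, not part of the repo): initialize a single-process gloo group before anything gets wrapped in DistributedDataParallel, since the NCCL backend is not built on Windows.

import os
import torch.distributed as dist

# Hypothetical single-process workaround: create a default gloo process group so
# that DistributedDataParallel has a group to use (NCCL is unavailable on Windows).
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
if not dist.is_initialized():
    dist.init_process_group(backend='gloo', rank=0, world_size=1)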

Question about the mask sampling

Hi, I read the paper JEPA and it is an effective way to learn temporal information better than other works like VideoMAE and UMT.

I have a question about the mask sampling.

To be clear, I do not mean to review or criticize the paper, but I want to reproduce the work exactly.

Question 01) When I instantiate a mask generator and then sample a mask, it sometimes masks only the first N frames.

For example, the source code below describes the situation.

mg = BlockMaskGenerator(aspect_ratio=(0.75, 1.5), npred=8, spatial_pred_mask_scale=(0.15, 0.15), temporal_pred_mask_scale=(1., 1.), max_context_frames_ratio=1.0, image_size=(64, 64), num_frames=2, patch_size=(16, 16), temporal_stride=1)
mask_enc, mask_pred = mg(16)
print(mask_enc)

it outputs

tensor([[ 6,  7,  8, 11, 12, 13, 22, 23],
        [ 2,  3,  4,  5, 10, 11, 15, 18],
        [ 0,  3,  6,  7,  8, 11, 15, 16],
        [ 4,  5, 12, 13, 20, 21, 28, 29],
        [ 2,  3, 10, 11, 12, 13, 14, 15],
        [ 3,  4,  5, 12, 13, 14, 15, 19],
        [ 3,  8, 12, 13, 14, 15, 19, 24],
        [ 0,  3,  8,  9, 15, 16, 19, 24],
        [ 7,  8, 11, 12, 13, 23, 24, 27],
        [ 4,  7,  8, 12, 13, 20, 23, 24],
        [ 2,  3,  4,  8, 18, 19, 20, 24],
        [ 3,  4,  8,  9, 10, 11, 19, 20],
        [ 0,  7,  8,  9, 10, 11, 12, 15],
        [ 0,  1,  4,  7, 11, 12, 13, 16],
        [ 6,  7, 11, 12, 15, 22, 23, 27],
        [ 4,  5,  8,  9, 10, 11, 14, 15]])

In some cases like mask_enc[-1] and mask_enc[-4], the mask is applied only to the first frame.
(There are 2 frames and 16 patches per frame, so indices like [4, 5, 8, 9, 10, 11, 14, 15] fall in the first frame only, because any index under 16 belongs to the first frame.)

In this case, for some batches, the model seems to use only part of the frames (e.g., 4 masked frames out of 8 frames) and is required to reconstruct the entire set of patches using only some patches from some of the frames (e.g., reconstruct 8 frames using 4 masked frames).

Is my analysis correct? If so, it might not be the same as the description of the paper that says the mask is the same for all frames.

3D Multi-Block Masking. We use a simple 3D extension of the block masking strategy employed
for images (Bao et al., 2021). Given a video, we sample several (possibly overlapping) spatially
continuous blocks with various aspect ratios and take their union to construct a single mask. This
spatial mask is then repeated across the entire temporal dimension. Masking a large continuous
block that covers the full temporal dimension limits information leakage due to the spatial and
temporal redundancy of videos, and results in a harder prediction task (Tong et al., 2022).

In this case, the masking strategy does not work as intended to limit information leakage.

Question 02) The sum of the visible and invisible masks seems not to be the same as the total number of patches.

When I print the shape of each mask, I get the output like below:

print(mask_enc.shape)
print(mask_pred.shape)

torch.Size([16, 8])
torch.Size([16, 16])

There are 32 patches (2 frames * 16 patches for each frame = 32), but the sum of the lengths is less than the total patch count.

Discussion

The second question might not be that problematic. It uses part of the visible patches for each sample to reconstruct part of the input video, and partial reconstruction in MAE has been shown to be effective in the paper [1].

[1] CrossMAE: Rethinking Patch Dependence for Masked Autoencoders

Approach (if the analysis is correct and the behavior is not intended)

However, the first question can affect the performance because the masking method aims to block the information leakage between the frames, specifically, preventing the model from copying the near patches at the different frames.

To resolve the problem, I think the masking block should be sampled for a single frame and repeated along the time axis with an offset (the number of patches in each frame).

I hope the discussion improves the clarity of the source code and the paper.

Thanks.

Update

The source code below can be a way to fix the mask sampling method.

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_enc = min_keep_pred = self.duration * self.height * self.width
        for _ in range(batch_size):

            empty_context = True
            while empty_context:

                mask_e = torch.ones((1, self.height, self.width), dtype=torch.int32)
                for _ in range(self.npred):
                    mask_e *= self._sample_block_mask(p_size)
                mask_e = mask_e.flatten()

                mask_p = torch.argwhere(mask_e == 0).squeeze()
                mask_e = torch.nonzero(mask_e).squeeze()

                empty_context = (len(mask_e) == 0)
                if not empty_context:
                    min_keep_pred = min(min_keep_pred, len(mask_p))
                    min_keep_enc = min(min_keep_enc, len(mask_e))
                    collated_masks_pred.append(mask_p)
                    collated_masks_enc.append(mask_e)

        if self.max_keep is not None:
            min_keep_enc = min(min_keep_enc, self.max_keep)

        # --
        return self._truncate_mask(collated_masks_enc, min_keep_enc), self._truncate_mask(collated_masks_pred, min_keep_pred)
    
    def _truncate_mask(self, masks, min_keep):
        result = []
        for cm in masks:
            # choice min_keep items randomly
            idx = torch.randperm(len(cm))[:min_keep]
            cm = cm[idx]
            tmp = torch.zeros((1, self.height, self.width), dtype=torch.int32)
            tmp.flatten()[cm] = 1
            tmp = tmp.repeat(self.duration, 1, 1)
            tmp = torch.nonzero(tmp.flatten()).squeeze()
            result.append(tmp)
        return torch.utils.data.default_collate(result)

For a sanity check, I ran the code without "tmp = torch.nonzero(tmp.flatten()).squeeze()".

The outputs are like:

[image: visualization of the sampled masks from the modified code]

load the pretrained model in kaggle to interact directly with it

I am trying to load the pretrained model for ImageNet-1K in Kaggle to interact with it, but the performance I am getting is random at best.
Any help is much appreciated.
Dataset required in Kaggle: imagenet-1k-resized-256

I copied the relevant pieces from the eval script.
The code is as follows; it takes about a minute to run, mostly for the download.

"""

get the repo in cell 1

!git clone https://github.com/facebookresearch/jepa.git
import os
os.chdir('/kaggle/working/jepa')
!pip install .

!wget https://dl.fbaipublicfiles.com/jepa/vitl16/in1k-probe.pth.tar
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar

config for the model i want

import yaml
with open('/kaggle/working/jepa/configs/evals/vitl16_in1k.yaml', 'r') as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)

params['pretrain']['folder'] = '/kaggle/working/jepa'
params['pretrain']['checkpoint'] = 'vitl16.pth.tar'

loading the model

import jepa.src.models.vision_transformer as vit
import torch

def load_pretrained(
    encoder,
    pretrained,
    checkpoint_key='target_encoder'
):
    print(f'Loading pretrained model from {pretrained}')
    checkpoint = torch.load(pretrained, map_location='cpu')
    try:
        pretrained_dict = checkpoint[checkpoint_key]
    except Exception:
        pretrained_dict = checkpoint['encoder']

    pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
    for k, v in encoder.state_dict().items():
        if k not in pretrained_dict:
            print(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            print(f'key "{k}" is of different shape in model and loaded state dict')
            pretrained_dict[k] = v
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    print(f'loaded pretrained model with msg: {msg}')
    print(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}')
    del checkpoint
    return encoder

def init_model(
    device,
    pretrained,
    model_name,
    patch_size=16,
    crop_size=224,
    # Video specific parameters
    frames_per_clip=16,
    tubelet_size=2,
    use_sdpa=False,
    use_SiLU=False,
    tight_SiLU=True,
    uniform_power=False,
    checkpoint_key='target_encoder'
):
    encoder = vit.__dict__[model_name](
        img_size=crop_size,
        patch_size=patch_size,
        num_frames=frames_per_clip,
        tubelet_size=tubelet_size,
        uniform_power=uniform_power,
        use_sdpa=use_sdpa,
        use_SiLU=use_SiLU,
        tight_SiLU=tight_SiLU,
    )
    if frames_per_clip > 1:
        def forward_prehook(module, input):
            input = input[0]  # [B, C, H, W]
            input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1)
            return (input)

        encoder.register_forward_pre_hook(forward_prehook)

    encoder.to(device)

    encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key)
    return encoder

args_eval = params

args_pretrain = args_eval.get('pretrain')

checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder')
model_name = args_pretrain.get('model_name', None)
patch_size = args_pretrain.get('patch_size', None)
pretrain_folder = args_pretrain.get('folder', None)
ckp_fname = args_pretrain.get('checkpoint', None)
tag = args_pretrain.get('write_tag', None)
use_sdpa = args_pretrain.get('use_sdpa', True)
use_SiLU = args_pretrain.get('use_silu', False)
tight_SiLU = args_pretrain.get('tight_silu', True)
uniform_power = args_pretrain.get('uniform_power', False)
pretrained_path = os.path.join(pretrain_folder, ckp_fname)

Optional [for Video model]:

tubelet_size = args_pretrain.get('tubelet_size', 2)
frames_per_clip = args_pretrain.get('frames_per_clip', 1)

args_data = args_eval.get('data')
resolution = args_data.get('resolution', 224)
num_classes = args_data.get('num_classes')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = init_model(
crop_size=resolution,
device=device,
pretrained=pretrained_path,
model_name=model_name,
patch_size=patch_size,
frames_per_clip=1,
tubelet_size=1,
uniform_power=uniform_power,
checkpoint_key=checkpoint_key,
use_SiLU=use_SiLU,
tight_SiLU=tight_SiLU,
use_sdpa=use_sdpa)

encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False

print(encoder)

#loading the classifier
from jepa.src.models.attentive_pooler import AttentiveClassifier

classifier = AttentiveClassifier(
embed_dim=encoder.embed_dim,
num_heads=encoder.num_heads,
depth=1,
num_classes=num_classes
).to(device)

checkpoint = torch.load("/kaggle/working/jepa/in1k-probe.pth.tar", map_location=torch.device('cpu'))
pretrained_dict = checkpoint['classifier']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}

print(classifier)

msg = classifier.load_state_dict(pretrained_dict)
print(msg)

evaluating

from PIL import Image
from io import BytesIO
import pickle
import os
import pandas as pd

import torch
from torchvision import transforms

transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])

parquet_file_path = "/kaggle/input/imagenet-1k-resized-256/data/train-00001-of-00052-886eb11e764e42fe.parquet"
df = pd.read_parquet(parquet_file_path)
print(df.shape)

file_path = "/kaggle/input/imagenet-1k-resized-256/classes.pkl"
with open(file_path, "rb") as f:
    classes = pickle.load(f)

for idx, row in df.sample(n=15).iterrows():
    img = Image.open(BytesIO(row['image']['bytes']))
    outs = classifier(encoder(transform(img).unsqueeze(0).to(device)))
    values, indices = torch.topk(outs, 10)
    display(img)
    print(f'real {row["label"]}')
    for n, i in enumerate(indices[0]):
        print(f'{i} : class {classes[int(i)]} value {values[0][n]} ')

"""

Release of "conditional diffusion decoder"

I wonder whether you will release the conditional diffusion decoder architecture, or the code for this part of the V-JEPA evaluation?
I would also like to evaluate my training result after training on my custom dataset.

Thanks.

Evals not eval for evaluation script (typo)

In the README "Launching Evaluations" section there is a typo in the three following scripts:

python -m evals.main \
  --fname configs/eval/vitl16_in1k.yaml \
  --devices cuda:0 cuda:1 cuda:2
python -m evals.main_distributed \
  --fname configs/eval/vitl16_in1k.yaml \
  --folder $path_to_save_stderr_and_stdout \
  --partition $slurm_partition
python -m evals.main_distributed \
  --fname configs/eval/vitl16_k400.yaml \
  --folder $path_to_save_stderr_and_stdout \
  --partition $slurm_partition

for the fname argument eval should be evals.

video_dataset: Random sample assignment

In the file video_dataset.py, the __getitem__ does not make sense to me:

    def __getitem__(self, index):
        sample = self.samples[index]

        # Keep trying to load videos until you find a valid sample
        loaded_video = False
        while not loaded_video:
            buffer, clip_indices = self.loadvideo_decord(sample)  # [T H W 3]
            loaded_video = len(buffer) > 0
            if not loaded_video:
                index = np.random.randint(self.__len__())
                sample = self.samples[index]

        # Label/annotations for video
        label = self.labels[index]

        def split_into_clips(video):
            """ Split video into a list of clips """
            fpc = self.frames_per_clip
            nc = self.num_clips
            return [video[i*fpc:(i+1)*fpc] for i in range(nc)]

        # Parse video into frames & apply data augmentations
        if self.shared_transform is not None:
            buffer = self.shared_transform(buffer)
        buffer = split_into_clips(buffer)
        if self.transform is not None:
            buffer = [self.transform(clip) for clip in buffer]

        return buffer, label, clip_indices

Particularly, the following:

            if not loaded_video:
                index = np.random.randint(self.__len__())
                sample = self.samples[index]

In the current setup (at least in eval), samples are file paths to videos. So, here we're replacing the video output with a random other video and returning it as the video at the current index with the label for the current index?

Worst case:
This could mess up validation (if the labels are used)
Best case:
Random double videos

Or maybe I'm missing something?

Evals, not Eval

Hi, I believe the config files of the Launching Evaluations section in the Readme.md file belong to the Evals folder rather than Eval. :)

typo/bug in pooler?

Probably enumerate(self.blocks, 1) ?

in attentive_pooler.py

for layer_id, layer in enumerate(1, self.blocks):
    rescale(layer.attn.proj.weight.data, layer_id + 1)
    rescale(layer.mlp.fc2.weight.data, layer_id + 1)
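If the intent is indeed to enumerate the blocks, a corrected form of the loop could look like the sketch below (keeping the rest of the lines unchanged; this follows the reporter's suggestion and is not necessarily the repository's eventual patch):

for layer_id, layer in enumerate(self.blocks, 1):
    rescale(layer.attn.proj.weight.data, layer_id + 1)
    rescale(layer.mlp.fc2.weight.data, layer_id + 1)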

few questions on feature detection

Hi, I just need some clarification on whether I have things correct; this is based on the video eval. The dataloader divides a single video clip from a single class into 16 frames and resizes to 224? [1, 3, 16, 224, 224]. The two lists wrapping this tensor are clips and views, views being the number of splices for a video in a class, and the number of clips being different classes? How are these two different, if I am mistaken?

The forward returns 1568 which is 14x14 x [clip length/2], which implies the delta/feature_changes from frame_a to frame_b? So it really wouldn't matter if we make positional embeddings to adjust the frame length from 2 to 200? There is no cross attention implementation yet for video inference so I am assuming it doesn't matter on the number of views per clip or number of classes passed through at any given time?

Thank you kindly.

Use of Conditional (possibly latent) variables (z) in predictor from original JEPA paper.

Hey Folks,

I have been looking through V-JEPA and its predecessors, and I am trying to find out whether V-JEPA makes use of the conditional variables in the predictor, as I am struggling to tell from the code myself (I am relatively new to ML). There are limited mentions of it in the I-JEPA and V-JEPA papers, so I was wondering if it is something for future research.

Thanks,
Adam

long video

Dear V-Jepa team,

Thank you for sharing this great work; I really enjoyed it.

If I understand correctly, the model is only trained with videos of 16 frames (after frame skipping, around 3s). Does it work with longer videos (>60 frames, i.e. >10s or >30s)? Or do I need to fine-tune it?

Thank you for your help.

Best Wishes,

Zongze

Crashes after first epoch because of leaked semaphores

Hi, I am running an evaluation on a small dataset (a train dataset of 22 labeled videos and a val dataset of 2 labeled videos). It crashes after the first epoch, after my RAM gets maxed out.

Error received: /opt/conda/envs/jepa-p10/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 88 leaked semaphore objects to clean up at shutdown warnings.warn('resource_tracker: There appear to be %d '

Here is config file:

{   'data': {   'dataset_train': '/home/ubuntu/dev/jepa/val_dataset.csv',
                'dataset_type': 'VideoDataset',
                'dataset_val': '/home/ubuntu/dev/jepa/train_dataset.csv',
                'frame_step': 4,
                'frames_per_clip': 16,
                'num_classes': 2,
                'num_segments': 2,
                'num_views_per_segment': 3},
    'eval_name': 'video_classification_frozen',
    'nodes': 1,
    'optimization': {   'attend_across_segments': True,
                        'batch_size': 1,
                        'final_lr': 0.0,
                        'lr': 0.001,
                        'num_epochs': 20,
                        'resolution': 224,
                        'start_lr': 0.001,
                        'use_bfloat16': True,
                        'warmup': 0.0,
                        'weight_decay': 0.01},
    'pretrain': {   'checkpoint': 'vitl16.pth.tar',
                    'checkpoint_key': 'target_encoder',
                    'clip_duration': None,
                    'folder': './',
                    'frames_per_clip': 16,
                    'model_name': 'vit_large',
                    'patch_size': 16,
                    'tight_silu': False,
                    'tubelet_size': 2,
                    'uniform_power': True,
                    'use_sdpa': True,
                    'use_silu': False,
                    'write_tag': 'jepa'},
    'resume_checkpoint': False,
    'tag': 'ssv2-16x2x3',
    'tasks_per_node': 1}

Please assist.
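A common mitigation for worker-related memory growth, shown here with generic PyTorch DataLoader arguments rather than this repo's own config keys, is to reduce the number of data-loading worker processes; `dataset` below stands for the VideoDataset built from the CSVs above:

    from torch.utils.data import DataLoader

    # Generic PyTorch settings; the repo's own config keys for these may differ.
    loader = DataLoader(
        dataset,                 # assumed: the VideoDataset built from the CSVs above
        batch_size=1,
        num_workers=2,           # fewer worker processes -> fewer shared-memory semaphores
        pin_memory=False,
        persistent_workers=False,
    )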

How to reproduce video recognition Acc in the Table?

Thank you very much for sharing your great work!
I tried to reproduce the video recognition results but got very low accuracy.
Could you give me some advice if I missed something, or kindly provide a script that reproduces the accuracy in the table?

I tested the model based on this script: jepa/evals/video_classification_frozen/eval.py, and removed code related to training.
Model & config:
Encoder: vith16.pth.tar
Classifier: vith16-k400-probe.pth.tar
Config: vith16_k400_16x8x3
Example data, first 5 videos of k400 val-set, label is "abseiling":
0wR5jVB-WPk_000417_000427.mp4
3caPS4FHFF8_000036_000046.mp4
3yaoNwz99xM_000062_000072.mp4
6IbvOJxXnOo_000047_000057.mp4
6_4kjPiQr7w_000191_000201.mp4
Result:
Index: 0 , Predict: 198
Index: 1 , Predict: 198
Index: 2 , Predict: 198
Index: 3 , Predict: 211
Index: 4 , Predict: 198

Label "abseiling" should be 0, accoring to willprice/KINETICS_LABELS.md
So the predictions are all wrong?
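One way to sanity-check this is to map the predicted indices back to class names using the same label ordering the probe was trained with (a sketch; the file name and format here are hypothetical, and the ordering must match the training CSV rather than necessarily KINETICS_LABELS.md):

    # Hypothetical mapping file: one class name per line, in the SAME order that was
    # used to build the K400 training CSV for the attentive probe.
    with open("k400_class_names.txt") as f:
        classes = [line.strip() for line in f]

    for idx in [198, 198, 198, 211, 198]:  # predictions reported above
        print(idx, classes[idx])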

Cannot access the config file vitl16_k400.yaml

When I click the link to download configs/eval/vitl16_k400.yaml, I get the following error:

404 - page not found
The main branch of jepa does not contain the path configs/eval/vitl16_k400.yaml.

FSDP Support

This is a bit of a technical challenge and/or question. Both I-JEPA and V-JEPA use DDP and not FSDP. This puts an inherent cap on the size of the models that can be used, namely the size of a single GPU's memory.

I'm wondering if there is any thought being put into the support of JEPA with FSDP. In my mind, the flow would be to

  1. Ensure that the model sharding of the target and context encoders is equivalent (see the sketch after this list).
  2. Update only the sharded parameters on a particular node (could even be a performance improvement vs. DDP).
  3. During forward passes, share the locally updated weights to all nodes.
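A minimal sketch of step 1, assuming the two ViT encoders are built from a transformer block class Block (the import path below is an assumption) and wrapped with the same auto-wrap policy so their sharding matches:

    import functools

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    from src.models.vision_transformer import Block  # assumed location of the ViT block class

    def wrap_encoders(context_encoder: nn.Module, target_encoder: nn.Module):
        # One shared wrap policy so both encoders are sharded identically (step 1);
        # EMA updates of the target encoder can then be applied shard-by-shard (step 2).
        policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        context_encoder = FSDP(context_encoder, auto_wrap_policy=policy, use_orig_params=True)
        target_encoder = FSDP(target_encoder, auto_wrap_policy=policy, use_orig_params=True)
        return context_encoder, target_encoder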

I attempted to implement something like this on my side, though FSDP seems to shard the parameters a bit sporadically, e.g. not following 1. above.

Any suggestions?

Experimental setup for the Low-Shot Frozen Evaluation (Table 7)

Hey, I was wondering what number of optimization steps (or epochs) and what batch size were used for the Low-Shot Frozen Evaluation experiment (Table 7 in the V-JEPA paper).

Are any other hyperparameters different from the experiments that use the full training set?

Share HF Model checkpoints to the Hugging Face Hub 🤗

Hi there,

Congratulations on such a brilliant release. I'm VB, leading the advocacy efforts for open source at Hugging Face. I saw that the model checkpoints are released via CDN URLs; it'd be great if you could upload them to the Hugging Face Hub as well.

These ckpts can be part of the Facebook org on the Hugging Face Hub: https://huggingface.co/facebook :)

Uploading model checkpoints to the Hugging Face Hub comes with several advantages:

  1. It increases the visibility of the model checkpoints on the Hub.
  2. It makes it easier for people to download the different weights and track the number of downloads too.
  3. It makes it surprisingly easy to version control weights as well.
  4. Here's a quick guide explaining how to upload models over on the Hub: https://huggingface.co/docs/hub/en/models-uploading

In addition to this, you can add support for directly loading the model checkpoints via the huggingface_hub library as well: https://huggingface.co/docs/huggingface_hub/v0.16.3/en/guides/download (it's just 2 lines of code).
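For illustration, the two-line pattern looks roughly like this (the repo id and filename below are hypothetical placeholders, not published checkpoints):

    from huggingface_hub import hf_hub_download

    # Hypothetical repo id and filename, for illustration only.
    ckpt_path = hf_hub_download(repo_id="facebook/vjepa-vitl16", filename="vitl16.pth.tar")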

Do let me know if you need any assistance with this.

Cheers,
VB

Training Loss Increasing After Initial Decrease with Custom Video Dataset

Hello Everyone,

I've been working with the V-JEPA model for a self-supervised learning project using a custom video dataset. Initially, the training loss decreases as expected, but starts to increase significantly after reaching a minimum. This behavior persists across multiple training sessions with different hyperparameters.

[Figure: jepa_loss_small_collapse — training loss curve showing the loss rising again after an initial decrease]

Configuration:

Data Setup

  • Dataset Type: VideoDataset
  • Batch Size: 24
  • Number of Clips: 1
  • Number of Frames per Clip: 16
  • Tubelet Size: 4
  • Sampling Rate: 2
  • Crop Size: 224
  • Patch Size: 16
  • Memory Pinning: true
  • Number of Workers: 8
  • Filter Short Videos: false

Data Augmentation

  • Auto Augment: false
  • Motion Shift: false
  • Random Resize Aspect Ratio: [0.75, 1.35]
  • Random Resize Scale: [0.3, 1.0]
  • Random Erase Probability (reprob): 0.0

Loss Configuration

  • Loss Exponent: 1.0
  • Regularization Coefficient: 0.0

Mask Settings

  • Aspect Ratio: [0.75, 1.5]
  • Number of Blocks: [8, 2]
  • Spatial Scale: [0.15, 0.7]
  • Temporal Scale: [1.0, 1.0]
  • Max Temporal Keep: 1.0

Meta Configuration

  • Seed: 234
  • Evaluation Frequency: Every 100 epochs
  • Use SDPA: true
  • Data Type: float16

Model Configuration

  • Model Name: vit_small
  • Predictor Depth: 12
  • Predictor Embedding Dimension: 384
  • Uniform Power: true
  • Use Mask Tokens: true
  • Zero Initialize Mask Tokens: true

Optimization

  • Iterations per Epoch: 300
  • IPE Scale: 1.25
  • Gradient Clipping: 10.0
  • Weight Decay: 0.04
  • Final Weight Decay: 0.4
  • Epochs: 500
  • Warmup: 40
  • Start Learning Rate: 0.0004
  • Learning Rate: 0.000825
  • Final Learning Rate: 1.0e-06
  • Exponential Moving Average: [0.998, 1.0]

Questions:

  1. Has anyone else encountered similar issues when training on custom datasets, particularly with video data?
  2. Are there recommended strategies for adjusting the training regimen or model configuration that might stabilize the loss?
  3. Could this be related to the specific characteristics of video data in the custom dataset that might require different handling or preprocessing?

Any insights or suggestions would be greatly appreciated. Thank you for your support!

Best regards,

@MidoAssran

Is AutoAugment used for videos or not

Hi,
According to the paper, AutoAugment is only used for image classification (with the cross-attention head), but in the code AA is enabled for videos as well. Can you confirm which is correct?

How to reproduce video recognition Acc in the Table?

Thank you very much for sharing your great work!
I tried to reproduce the video recognition results but got very low accuracy.
Could you give me some advice if I missed something, or kindly provide a script that reproduces the accuracy in the table?
I tested the model based on this script: jepa/evals/video_classification_frozen/eval.py
Config: vitl16_ssv2_16x2x3.yaml

    nodes: 8
    tasks_per_node: 8
    tag: ssv2-16x2x3
    eval_name: video_classification_frozen
    resume_checkpoint: false
    data:
      dataset_train: xx/ssv2_train.csv
      dataset_val: xx/ssv2_val.csv
      dataset_type: VideoDataset
      num_classes: 174
      frames_per_clip: 16
      num_segments: 1  # 2
      num_views_per_segment: 3
      frame_step: 4
    optimization:
      attend_across_segments: true
      num_epochs: 20
      resolution: 224
      batch_size: 16  # 4
      weight_decay: 0.01
      lr: 0.001
      start_lr: 0.001
      final_lr: 0.0
      warmup: 0.
      use_bfloat16: true
    pretrain:
      model_name: vit_large
      checkpoint_key: target_encoder
      clip_duration: null
      frames_per_clip: 16
      tubelet_size: 2
      uniform_power: true
      use_silu: false
      tight_silu: false
      use_sdpa: true
      patch_size: 16
      folder: xx/JEPA
      checkpoint: vitl16.pth
      write_tag: jepa

License

Hi,
Great work on V-JEPA! Any plans to release it under an open-source license?
Thanks!

Question: no release of decoder checkpoint?

Hi, thanks for the great work.

I checked the downloaded checkpoints and found that there is no checkpoint for the decoder model mentioned in the paper, the one used to map the predicted features back to pixel space.

I would like to ask whether you have any plans to release that checkpoint, since it is also part of what makes JEPA so useful for further research in downstream areas.

Thanks!

Question about expected finetuning performance

Hello, I read the paper and the performance reported using linear probing is impressive.
My question is: if I were to fine-tune the model on iNat18 with a linear head, would the performance, in your experience, be better than that of a model pretrained with MAE, MAWS, etc.?

Thanks and congrats for the work!

A few thoughts on JEPA: task-goal / task-definition

(1) LeCun and his collaborators or doctoral students are all experts, and I greatly admire them.
(2) Technically, my understanding may be incorrect.

Just technical thoughts!!!

I keep feeling that JEPA is not quite right, or not a big enough leap, in terms of its task goal / task definition, which means the JEPA family of algorithms may still fall short even when they are optimized well.

(1) Even within vision alone, JEPA does not explain what representations it learns internally, nor what a representation such as a world model actually is. Are they still just the distributed weights of an ordinary neural network, are there special structures such as latent representations, or are there explicit 3D/4D representations? Without delving into JEPA's details, looking at the network at a high level does not reveal a significant difference, nor a leap in task definition toward stronger AI such as AGI/ASI.
Even at the forefront, as with JEPA, I believe there has not been a fundamental breakthrough, even for pure vision tasks. I don't think an ideal, powerful vision system should be a one-way, one-shot, train-once-infer-many system like LLMs; each act of visual processing involves multiple recognitions running in parallel, alternating and iterating repeatedly before producing the final output.

(2) As for the ultimate form of the vision task, I personally believe it should be, like humans, the ability to construct a 3-dimensional world, and even a dynamic 4-dimensional one, from a single image, a stereo pair (with left/right disparity), or a video, without needing the camera/observation position (and in most cases without physical-level precision). This could be a latent representation, but an explicit one would be better (point cloud, surface mesh, Gaussian splatting, ...).

(3) To support various high-level applications such as differentiable prediction/inference/planning based on vision, this latent or explicit representation should be usable by neural networks (for example, estimating how moving objects maneuver around a building on the road). Depending on the application, the 3D/4D representation may also need estimated distances and semantic labels (human, building; stone, swamp; water, fog, glass, ...; old or new; soft or hard; ...).

(4) I want to append some samples (to follow) about vision AI, not limited to vision only.

Not questioning the great minds, just pondering what the ultimate definition of the visual task is. Once JEPA is done well, how far away are we from it?
