pierrefdz commented on August 27, 2024

Thanks, I'll try to have a more thorough look at it in the following days.

Some points that might cause the problem:

  • Maybe try switching to SGD with a few different learning rates (and try without LR scheduling; it should only have a minor impact).
  • Remove the Gaussian blur augmentation.
  • Check that the forward function in your MLP does the right thing (that every tensor has the right dimensions, that you average over the time dimension, etc.). My code snippet for feeding the videos is:
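import torch

# inp: (B, C, T, H, W) clips; model: frozen DINOv2 backbone; linear_classifier: the trainable head
with torch.no_grad():
    B,C,T,H,W = inp.shape
    inp = inp.transpose(1,2).reshape(B*T,C,H,W) # b c t h w -> b t c h w -> b*t c h w
    output = model(inp) # b*t c h w -> b*t d (one feature vector per frame)
    output = output.reshape(B,T,-1) # b*t d -> b t d
    output = output.mean(dim=-2) # b t d -> b d (average over time)
output = linear_classifier(output) # b d -> b l (outside no_grad, so the classifier gets gradients)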
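To sanity-check that reshape logic without downloading DINOv2, here is a minimal sketch that swaps in a hypothetical stand-in encoder (`ToyEncoder`, which global-average-pools each frame and projects it to 384 dims; it is not part of dinov2). It checks that flattening to b*t and reshaping back keeps each clip's frames together, by comparing against an explicit per-frame loop.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a DINOv2 backbone: (N, C, H, W) -> (N, D)."""

    def __init__(self, in_ch: int = 3, dim: int = 384):
        super().__init__()
        self.proj = nn.Linear(in_ch, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x.mean(dim=(2, 3)))  # global average pool, then project


def clip_features(model: nn.Module, inp: torch.Tensor) -> torch.Tensor:
    """Average per-frame features over time; inp has shape (B, C, T, H, W)."""
    with torch.no_grad():
        B, C, T, H, W = inp.shape
        frames = inp.transpose(1, 2).reshape(B * T, C, H, W)  # b c t h w -> b*t c h w
        out = model(frames).reshape(B, T, -1)                  # b*t d -> b t d
        return out.mean(dim=1)                                 # b t d -> b d


torch.manual_seed(0)
encoder = ToyEncoder().eval()
inp = torch.randn(2, 3, 8, 32, 32)  # B=2 clips, T=8 frames each
feats = clip_features(encoder, inp)
print(feats.shape)  # torch.Size([2, 384])

# Reference: encode each frame of clip 1 separately and average by hand.
with torch.no_grad():
    ref = torch.stack([encoder(inp[1:2, :, t]) for t in range(8)]).mean(dim=0)
print(torch.allclose(feats[1:2], ref, atol=1e-5))  # True
```

If the check fails in your own model, the usual culprit is reshaping a (B, T, ...) tensor as if it were (B, C, T, ...), or vice versa.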

from dinov2.

ccharest93 commented on August 27, 2024

For N extracted frames, pass them through the network to get N CLS tokens. There are two ways to go from there:

  1. Average the CLS token embeddings over the N frames, so that the input to your classification head is 1 × Embed_Dim.
  2. Concatenate the N CLS tokens, so that the input to your classification head is N × Embed_Dim.

Option 1 trains fewer parameters because the input dimension is smaller; option 2 retains more information but has a much higher parameter count in the classification head (roughly N times as many weights).

You could also do something in between: average the first N/2 CLS tokens and the last N/2 CLS tokens separately, then concatenate the two averages, giving your classification head an input dimension of 2 × Embed_Dim.
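The three options can be sketched in PyTorch as follows, assuming the N CLS tokens of each clip are already stacked into a (B, N, D) tensor. The function names are illustrative, and the `nn.Linear` head stands in for whatever classification head you use; printing its parameter count makes the trade-off above concrete.

```python
import torch


def pool_mean(cls_tokens: torch.Tensor) -> torch.Tensor:
    """Option 1: (B, N, D) -> (B, D), average over frames."""
    return cls_tokens.mean(dim=1)


def pool_concat(cls_tokens: torch.Tensor) -> torch.Tensor:
    """Option 2: (B, N, D) -> (B, N*D), keep every frame in order."""
    return cls_tokens.flatten(start_dim=1)


def pool_halves(cls_tokens: torch.Tensor) -> torch.Tensor:
    """Middle ground: average each half of the clip, then concatenate -> (B, 2*D)."""
    first, second = cls_tokens.chunk(2, dim=1)
    return torch.cat([first.mean(dim=1), second.mean(dim=1)], dim=1)


B, N, D, num_classes = 4, 8, 384, 101
cls_tokens = torch.randn(B, N, D)  # placeholder for N CLS tokens per clip

for pool in (pool_mean, pool_concat, pool_halves):
    feats = pool(cls_tokens)
    head = torch.nn.Linear(feats.shape[1], num_classes)
    n_params = sum(p.numel() for p in head.parameters())
    print(pool.__name__, tuple(feats.shape), n_params)
# pool_mean (4, 384) 38885
# pool_concat (4, 3072) 310373
# pool_halves (4, 768) 77669
```

With an odd N, `chunk` splits the frames unevenly (the first half gets the extra frame), which is one possible convention among others.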

pierrefdz commented on August 27, 2024

@woctezuma noted that, as mentioned in the paper, concatenation makes it possible to "retain [...] temporal information" compared to average pooling, and asked how many frames were picked for the SSv2 task (guessing N=8 as well, pending confirmation from the authors).

Hi, author here!
I confirm that what @ccharest93 said is correct, and N=8 in both cases.

The underlying idea is that:

  • For UCF and K-400, the answer to the classification task can very often be obtained from good visual features alone (using only one frame already gives very high accuracy!), so there is less need for "temporal" information.
  • For SSv2, some labels like "Turning the camera left while filming something" or "Turning the camera right while filming something" need "temporal information", i.e. the order of the frames, to determine the action. Averaging the visual features would lose that, because the average is the same whatever order the frames come in.
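A quick way to see why averaging loses that information: the mean of the per-frame features is identical for a clip and its time-reversed copy, while the concatenation is not. This sketch uses random tensors as placeholder features; it only illustrates the order-invariance of the mean, not any result from the paper.

```python
import torch

torch.manual_seed(0)
T, D = 8, 384
clip = torch.randn(T, D)             # placeholder per-frame features, in temporal order
reversed_clip = clip.flip(dims=[0])  # the same frames, played backwards

# Mean pooling cannot tell the two apart...
print(torch.allclose(clip.mean(dim=0), reversed_clip.mean(dim=0), atol=1e-6))  # True
# ...while concatenation can.
print(torch.equal(clip.flatten(), reversed_clip.flatten()))  # False
```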

Remarks:

  • Increasing N gives better performance but increases loading and computing time (and requires different hyper-parameters to obtain the best performance).
  • If you want to go further, you can use ViViT (https://arxiv.org/abs/2103.15691) with something that looks like the factorized encoder of Fig. 1.
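As a rough illustration of that last remark (not what was used for the paper's linear-probe results, and not ViViT's own code), a small temporal transformer over the per-frame CLS tokens of a frozen backbone might look like this. The frozen DINOv2 model plays the role of the spatial encoder; `TemporalHead`, its depth, head count and learned positional embeddings are all hypothetical choices.

```python
import torch
import torch.nn as nn


class TemporalHead(nn.Module):
    """Hypothetical temporal encoder over per-frame CLS tokens, then a linear classifier.

    Expects exactly `num_frames` tokens per clip, in temporal order.
    """

    def __init__(self, dim: int = 384, num_frames: int = 8, num_classes: int = 101,
                 depth: int = 2, heads: int = 6):
        super().__init__()
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))                         # temporal CLS token
        self.pos = nn.Parameter(0.02 * torch.randn(1, num_frames + 1, dim))     # learned frame positions
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, frame_tokens: torch.Tensor) -> torch.Tensor:
        # frame_tokens: (B, T, D) CLS tokens from the frozen backbone
        B = frame_tokens.shape[0]
        x = torch.cat([self.cls.expand(B, -1, -1), frame_tokens], dim=1) + self.pos
        x = self.encoder(x)        # attention across time only
        return self.fc(x[:, 0])    # classify from the temporal CLS token


head = TemporalHead().eval()
logits = head(torch.randn(2, 8, 384))
print(logits.shape)  # torch.Size([2, 101])
```

Unlike the average, the positional embeddings let this head distinguish frame orders; it also has far more trainable parameters than a single linear layer, so it needs its own hyper-parameter tuning.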

I hope this is useful

pierrefdz commented on August 27, 2024

Closing as answered, thanks for your interest!

patricklabatut commented on August 27, 2024

If you are trying to reproduce the video action recognition results, these are described in the second paragraph of sub-section 7.2 of the paper. These results were obtained with a linear classifier trained on features from a number of evenly spaced frames, without any fine-tuning.
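One way to pick "a number of evenly spaced frames" is sketched below. The helper name is hypothetical; the rounding of `np.linspace` and the wrap-around for videos shorter than N frames mirror the `UCFDataset` code posted in this thread, not anything specified in the paper.

```python
import numpy as np


def evenly_spaced_indices(num_frames: int, n: int) -> np.ndarray:
    """Pick n frame indices spread uniformly from the first to the last frame.

    If the video has fewer than n frames, indices wrap around (frames repeat).
    """
    if num_frames >= n:
        return np.round(np.linspace(0, num_frames - 1, n)).astype(np.int64)
    return np.arange(n) % num_frames


print(evenly_spaced_indices(100, 8))  # 0, 14, 28, 42, 57, 71, 85, 99
print(evenly_spaced_indices(5, 8))    # 0, 1, 2, 3, 4, 0, 1, 2
```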

woctezuma commented on August 27, 2024

Also, as mentioned in the paper, concatenation makes it possible to "retain [...] temporal information", compared to average pooling.

Regarding "How many frames were picked for the SSv2 task?": I imagine it is N=8 frames as well, but I understand that you would like confirmation from the authors.

Batwho commented on August 27, 2024

Problem solved: it was due to a bug in the initialization of the validation dataloader.
Thank you @pierrefdz and feel free to close this issue.

steveice commented on August 27, 2024

Thank you for the detailed explanation. Could you provide more details about "For SSv2, we opt for concatenation to retain more temporal information than with feature averaging."? How many frames were picked for the SSv2 task?

Batwho commented on August 27, 2024

Could you please also provide the detailed linear classifier structure and the related key parameters? I tried using a simple MLP (with just two linear layers) on the UCF dataset, but the accuracy is pretty low.

pierrefdz commented on August 27, 2024

Hi, the results in the paper were obtained using a single linear layer as the classifier. What do you mean by "pretty low"? Would you be able to share your implementation and the hyper-parameters used to optimize the layer?
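For reference, a linear probe of that kind is just one `nn.Linear` trained on top of frozen features. The sketch below uses random placeholder features and labels and follows the suggestion in this thread to try SGD; the learning rate and momentum are placeholders, not the paper's hyper-parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_classes = 384, 101                  # ViT-S/14 feature size, number of UCF101 classes
probe = nn.Linear(dim, num_classes)          # the only trainable module: a single linear layer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(probe.parameters(), lr=0.01, momentum=0.9)  # placeholder values

# Placeholder data: in practice, time-averaged features from the frozen backbone and 0-based labels.
feats = torch.randn(32, dim)
labels = torch.randint(0, num_classes, (32,))

losses = []
for step in range(50):
    optimizer.zero_grad()
    loss = criterion(probe(feats), labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.2f}, last loss {losses[-1]:.2f}")
```

Being able to overfit a single small batch like this is a cheap check that labels, shapes and the optimizer are wired correctly before running full training.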

Batwho commented on August 27, 2024

Hi!
Thanks for your quick response. The accuracy I got after 10 epochs is around 5%, so I guess there is something I didn't do right. I tried a single linear layer from 384 input dims (using vits14) to 101 classes. The optimizer is Adam with lr=0.001, with a StepLR scheduler (step_size=1, gamma=0.95).

If this setting is expected to work, I can share my code as well.
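For reference, here is what that schedule does over 10 epochs, sketched with a dummy parameter in place of the linear layer: `StepLR` with `step_size=1` multiplies the learning rate by `gamma` after every `scheduler.step()`, so the 10th epoch runs at 0.001 × 0.95^9 ≈ 6.3e-4, about 63% of the initial rate.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy stand-in for the linear layer's weights
optimizer = torch.optim.Adam([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

lrs = []
for epoch in range(10):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()   # training for one epoch would happen here
    scheduler.step()   # decay once per epoch

print(f"{lrs[0]:.2e} -> {lrs[-1]:.2e}")  # 1.00e-03 -> 6.30e-04
```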

Batwho commented on August 27, 2024

code:

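import gc

import decord
import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
from PIL import Image
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Dataset Class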
class UCFDataset(torch.utils.data.Dataset):
  

    def __init__(self, dataset_dir, subset, video_list_file, frames_per_clip=16, video_dir=None):
        super().__init__()
        self.dataset_dir = dataset_dir
        # video_dir is optional and defaults to dataset_dir (assumed layout: videos live there)
        self.video_dir = video_dir if video_dir is not None else dataset_dir
        self.subset = subset
        self.video_list_file = video_list_file
        self.video_list = []
        self.labels = []

        # classInd.txt has lines like "1 ApplyEyeMakeup"; map class name -> 0-based label
        with open(f'{dataset_dir}/classInd.txt') as class_file:
            self.class_to_idx = {name: int(i) - 1
                                 for i, name in (line.split() for line in class_file if line.strip())}

        # NOTE: UCF101 defines three alternative train/test splits. Each test video of one split
        # appears in the train lists of the other two, so merging the three lists mixes train and test.
        for i in [1, 2, 3]:
            with open(f'{dataset_dir}/{video_list_file}{i}.txt') as video_names_file:
                lines = [line.strip() for line in video_names_file if line.strip()]
            if self.subset == "train":
                # train lists: "<class>/<video>.avi <1-based label>"
                names, label_strs = zip(*(line.split() for line in lines))
                self.video_list += names
                self.labels += [int(label) - 1 for label in label_strs]
            else:
                # test lists only hold paths: one label per video, taken from the class folder
                self.video_list += lines
                self.labels += [self.class_to_idx[line.split('/')[0]] for line in lines]

        self.frames_per_clip = frames_per_clip

        self.transform = tv.transforms.Compose([
          tv.transforms.GaussianBlur(9, sigma=(0.1, 2.0)),
          tv.transforms.Resize(256,interpolation=tv.transforms.InterpolationMode.BICUBIC),
          tv.transforms.CenterCrop(224),
          tv.transforms.ToTensor(),
          tv.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, idx):
        videoname = self.video_list[idx]
        vid = decord.VideoReader(f'{self.video_dir}/{videoname}', ctx=decord.cpu(0))
        nframes = len(vid)

        # if the video has fewer frames than frames_per_clip, repeat frames cyclically
        if nframes <= self.frames_per_clip:
            idxs = np.arange(self.frames_per_clip) % nframes

        # else, sample frames_per_clip uniformly separated frames
        else:
            idxs = np.linspace(0, nframes - 1, self.frames_per_clip)
            idxs = np.round(idxs).astype(np.int64)

        imgs = []
        for k in idxs:
            frame = Image.fromarray(vid[int(k)].asnumpy())  # decord frame -> HWC uint8 array -> PIL
            imgs.append(self.transform(frame))
        imgs = torch.stack(imgs)  # [T, C, H, W]

        # labels were resolved in __init__ for both the train and test subsets
        return imgs, self.labels[idx]

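class MLP(nn.Module):
    """Frozen DINOv2 encoder + average pooling over time + a single linear layer."""

    def __init__(self, dim, inner_dim, n_class, encoder):  # dim: DINOv2 output size (384 for ViT-S/14)
        super().__init__()
        self.encoder = encoder  # frozen backbone (requires_grad=False is set outside)
        # despite the class name, this is a single linear layer (inner_dim is unused)
        self.mlp = nn.Sequential(
            nn.Linear(dim, n_class),
        )

    def train(self, mode=True):
        super().train(mode)
        self.encoder.eval()  # keep the frozen backbone in eval mode while training the head
        return self

    def forward(self, x):
        # x: [B, T, C, H, W] (e.g. [16, 8, 3, 224, 224]) as returned by UCFDataset + DataLoader
        B, T = x.shape[:2]
        with torch.no_grad():
            # encode one time step at a time to bound memory: [B, C, H, W] -> [B, dim]
            feats = torch.stack([self.encoder(x[:, t]) for t in range(T)], dim=1)  # [B, T, dim]
        return self.mlp(feats.mean(dim=1))  # average over time: [B, dim] -> [B, n_class]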


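# data loading params (defined before the datasets/dataloaders below use them)
dataset_dir = "path/to/UCF101"  # placeholder: folder with classInd.txt, the split lists and the videos
frames_per_clip = 8
batch_size = 256
test_batch_size = 1
num_workers = 8
pin_memory = True
num_classes = 101
epochs = 10
lr = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
train_val_data = UCFDataset(dataset_dir=dataset_dir, subset="train", video_list_file="trainlist0", frames_per_clip=frames_per_clip)

train_len = int(0.85 * len(train_val_data))
train_val_split = [train_len, len(train_val_data) - train_len]

train_data, val_data = random_split(train_val_data, train_val_split)
test_data = UCFDataset(dataset_dir=dataset_dir, subset="test", video_list_file="testlist0", frames_per_clip=frames_per_clip)

# Dataloaders
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_data, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_data, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)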

dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vits14.to(device)
for param in dinov2_vits14.parameters():
    param.requires_grad= False
model = MLP(384,512,101,dinov2_vits14)
#frames, _ = next(iter(train_loader))
#tb_writer.add_graph(model, frames)
model.to(device)

# define the loss and optimizers
loss_criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.mlp.parameters(), lr)  # only the linear head is trained
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)




# training step for every epoch
def train_step(loader, epoch):
    model.train()
    total_epoch_loss = 0.0

    for batch_id, (video_data, labels) in enumerate(loader):
        video_data, labels = video_data.to(device), labels.to(device)

        optimizer.zero_grad()
        prediction = model(video_data)
        loss = loss_criterion(prediction, labels)
        loss.backward()
        optimizer.step()

        total_epoch_loss += loss.item()
        #tb_writer.add_scalar("Train/Loss",loss.item(),((len(loader))*(epoch-1))+batch_id)
        print(f"\n[Train Epoch]: {epoch} Train Loss: {loss.item()}")
    return total_epoch_loss


# validation step for every epoch
def val_step(loader, epoch=None):
    model.eval()
    total_loss = 0.0
    corrects = 0

    with torch.no_grad():
        for batch_id, (video_data, labels) in enumerate(loader):
            video_data, labels = video_data.to(device), labels.to(device)
            prediction = model(video_data)
            total_loss += loss_criterion(prediction, labels).item()
            corrects += (torch.argmax(prediction, dim=1) == labels).sum().item()

    # divide by the number of samples actually seen, not len(loader) * batch_size:
    # val_loader uses test_batch_size, and the last batch may be smaller
    accuracy = corrects / len(loader.dataset)
    avg_loss = total_loss / len(loader)

    print(f"\n[Val Epoch]: {epoch} , Accuracy: {accuracy}, Valid Loss: {avg_loss}")
    return accuracy

# Driving train test loop
for epoch in tqdm(range(1,epochs+1)):
    train_step(train_loader, epoch)
    val_step(val_loader, epoch)
    scheduler.step()
    torch.save(model,"dino_model.pt")

Batwho commented on August 27, 2024

Replying to @pierrefdz's suggestions (switch to SGD, remove the Gaussian blur augmentation, check the forward function of the MLP): could you please also share your number of epochs, batch size, and learning rate? I suspect the reason might be that I haven't trained for long enough, due to limited GPU RAM.

pierrefdz commented on August 27, 2024

Thanks for keeping me updated on this @Batwho.
Don't hesitate to re-open if you need anything else.
