
Comments (4)

rdroste commented on June 5, 2024

Exactly, you wouldn't need to keep all the domain-specific parameters when you fine-tune the model. When initializing the UNISAL model class, I would set sources=(my_source,), where my_source is the dataset among ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON') that most closely matches your target dataset. It might be worth trying all of them. For a static target dataset, the best choice would be sources=('SALICON',). For video data it's difficult to say a priori, but sources=('DHF1K',) might be a good default since DHF1K is the most varied of the video datasets. Afterwards, call the model's forward pass with pred = model(x, source=my_source). Hope that makes sense.
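
The source-selection advice above can be written down as a small helper. This is a hypothetical sketch, not part of UNISAL: it returns the suggested default (SALICON for static data, DHF1K for video) followed by the remaining training datasets, since trying all of them may be worthwhile.

```python
# The four training datasets (sources) of the pretrained UNISAL model.
UNISAL_SOURCES = ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')


def default_source(is_static):
    """Hypothetical helper: suggested starting source, SALICON for images, DHF1K for video."""
    return 'SALICON' if is_static else 'DHF1K'


def sources_to_try(is_static):
    """Hypothetical helper: all sources, with the suggested default first."""
    first = default_source(is_static)
    return [first] + [s for s in UNISAL_SOURCES if s != first]
```

Each candidate would then be used in both places, e.g. model = unisal.model.UNISAL(sources=(src,)) and pred = model(x, source=src), passing the variable rather than a string literal.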


rdroste commented on June 5, 2024

Hi Ekta, thanks for your interest in our work. A minimal example for fine-tuning the model is a good idea; I'll try to find some time soon to upload one.

However, one difficulty with a general fine-tuning example might be that the optimal fine-tuning method (learning rate, learning rate schedule, batch size, freezing different parts of the network, etc., etc.) really depends on the target dataset. Therefore you could manually load the UNISAL model and plug it into your own training script.

To load the pretrained model you can run something like:

import unisal
model = unisal.model.UNISAL()
model.load_best_weights('unisal/training_runs/pretrained_unisal')

If you want to load the model for one of the training datasets only, you could also run (untested):

import torch
import unisal
my_source = 'DHF1K'  # or 'Hollywood', 'UCFSports', 'SALICON': whichever matches your data most closely
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)

(Instead of using strict=False, which can fail silently, you could also remove the weights with keys 'rnn', 'post_rnn' and keys containing 'DHF1K', 'Hollywood' or 'UCFSports' from the state dict)
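
Removing the unused entries from the state dict, as suggested above, makes a mismatch fail loudly instead of silently. The function below is a hypothetical sketch, not part of UNISAL. It assumes the RNN weights have keys starting with 'rnn.' or 'post_rnn.' and that source-specific weights contain the dataset name somewhere in their key; print the keys of the loaded checkpoint to confirm this before relying on it.

```python
from collections import OrderedDict

ALL_SOURCES = ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')
# Assumption: RNN weights live under top-level modules named 'rnn' and 'post_rnn'.
RNN_MODULES = ('rnn', 'post_rnn')


def filter_state_dict(state_dict, keep_sources, drop_rnn=False, all_sources=ALL_SOURCES):
    """Hypothetical helper (not part of UNISAL): return a copy of `state_dict`
    without the weights of unused sources and, optionally, the RNN.

    Assumes source-specific keys contain the dataset name and RNN keys start
    with 'rnn.' or 'post_rnn.'; verify this against your checkpoint.
    """
    drop_sources = [s for s in all_sources if s not in keep_sources]
    filtered = OrderedDict()
    for key, value in state_dict.items():
        top_module = key.split('.', 1)[0]
        if drop_rnn and top_module in RNN_MODULES:
            continue  # skip GRU / post-RNN weights
        if any(source in key for source in drop_sources):
            continue  # skip weights belonging to unused domains
        filtered[key] = value
    return filtered
```

For a static-only model this might look like state = torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth') followed by model.load_state_dict(filter_state_dict(state, keep_sources=('SALICON',), drop_rnn=True)). With the default strict=True, any remaining mismatch raises an error. If you keep strict=False instead, load_state_dict returns the missing_keys and unexpected_keys, which are worth printing.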

If you want to use the model for static data only, you can reduce the model size by loading it without the GRU RNN by running (untested):

import torch
import unisal
my_source = 'SALICON'
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)

In your training code you can then call the model with

# ... your code here
my_source = 'SALICON'  # one of ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')
static = True  # True for static (image) data, False for video
prediction = model(training_batch, source=my_source, static=static)
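
Freezing part of the network, one of the fine-tuning choices mentioned earlier in this comment, can be done by switching off requires_grad for the parameters that should stay fixed. The helper below is a hypothetical sketch, not part of UNISAL; it works on any object with a PyTorch-style named_parameters() method. The 'cnn.' prefix in the usage note is an assumption about how the encoder parameters are named, so check the names yielded by model.named_parameters() first.

```python
def freeze_by_prefix(model, frozen_prefixes):
    """Hypothetical helper (not part of UNISAL): freeze every parameter whose
    name starts with one of `frozen_prefixes` and unfreeze all others.

    Works with any torch.nn.Module (anything exposing named_parameters()).
    Returns the names of the parameters that remain trainable.
    """
    prefixes = tuple(frozen_prefixes)
    trainable = []
    for name, param in model.named_parameters():
        # str.startswith accepts a tuple; an empty tuple matches nothing.
        frozen = name.startswith(prefixes)
        param.requires_grad = not frozen
        if not frozen:
            trainable.append(name)
    return trainable
```

For example, assuming the encoder's parameter names start with 'cnn.', you could call trainable = freeze_by_prefix(model, ('cnn.',)) and then build the optimizer from the remaining parameters only, e.g. torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4); the learning rate is a placeholder, not a recommended value.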


prashnani commented on June 5, 2024

Thanks for your response, @rdroste! I'll wait for your minimal example. 👍 It makes sense that a new dataset would involve some hyperparameter tuning.

For plugging UNISAL into my own training script:
It would be great to know which components of model.py are needed when training on just one new dataset (one that is not among the datasets used in your method).
At the moment, model.py seems to contain domain-specific normalization, multiple sources, etc. Are these components still needed when only one (new) dataset is used for training?


prashnani commented on June 5, 2024

Thanks, @rdroste! Let me give this a try.
