Exactly, you wouldn't need to keep all the domain-specific parameters when you fine-tune the model. When initializing the UNISAL model class, I would set `sources=(my_source,)`, where `my_source` is the dataset from ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON') that matches your target dataset most closely. It might be worth trying out all of them. For a static target dataset, the best choice would be `sources=('SALICON',)`. For video data it's difficult to say a priori, but `sources=('DHF1K',)` might be a good default option, since DHF1K is the most varied of the video datasets. Afterwards, call the model's forward pass with `pred = model(x, source=my_source)`. Hope that makes sense.
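As a rough illustration, the advice above can be written down as an ordered list of pretrained sources to try: SALICON for static data, and DHF1K first for video, followed by the other video datasets in case it pays off to try all of them. The helper `candidate_sources` is hypothetical and not part of UNISAL; the model calls in the comments only repeat the ones quoted in this thread.

```python
# Training datasets covered by the pretrained UNISAL weights (as listed in this thread).
VIDEO_SOURCES = ("DHF1K", "Hollywood", "UCFSports")
STATIC_SOURCES = ("SALICON",)


def candidate_sources(is_static: bool) -> tuple:
    """Hypothetical helper: order the pretrained sources to try for a new dataset.

    Static (image) data -> SALICON.
    Video data -> DHF1K first (the most varied video dataset), then the others.
    """
    if is_static:
        return STATIC_SOURCES
    return VIDEO_SOURCES


# Usage sketch (requires the unisal package; not executed here):
# for my_source in candidate_sources(is_static=False):
#     model = unisal.model.UNISAL(sources=(my_source,))
#     ...fine-tune and evaluate on your data...
#     pred = model(x, source=my_source)
```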
Hi Ekta, thanks for your interest in our work. A minimal example for fine-tuning the model is a good idea; I'll try to find some time soon to upload one.
However, one difficulty with a general fine-tuning example is that the optimal fine-tuning method (learning rate, learning rate schedule, batch size, which parts of the network to freeze, etc.) really depends on the target dataset. Therefore, you could instead load the UNISAL model manually and plug it into your own training script.
To load the pretrained model, you can run something like:

```python
import unisal

model = unisal.model.UNISAL()
model.load_best_weights('unisal/training_runs/pretrained_unisal')
```
If you want to load the model for only one of the training datasets, you could also run (untested):

```python
import torch
import unisal

# Pick whichever dataset matches your data most closely:
# one of ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')
my_source = 'DHF1K'
model = unisal.model.UNISAL(sources=(my_source,))
# strict=False ignores keys that don't match between the checkpoint and this smaller model
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)
```
(Instead of using `strict=False`, which can fail silently, you could also remove the unused weights from the state dict, e.g. for a SALICON-only model the keys 'rnn' and 'post_rnn' and all keys containing 'DHF1K', 'Hollywood' or 'UCFSports'.)
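A minimal sketch of the alternative in the parenthetical above: drop the unused entries from the pretrained state dict first, so the load can keep the default `strict=True` and any remaining mismatch raises an error instead of passing silently. The helper `filter_state_dict` is hypothetical, and the key patterns are taken from the comment; check them against the actual keys of `weights_best.pth` before relying on them.

```python
from collections import OrderedDict


def filter_state_dict(state_dict, drop_prefixes=(), drop_substrings=()):
    """Hypothetical helper: return a copy of `state_dict` without the entries whose
    key starts with any of `drop_prefixes` or contains any of `drop_substrings`."""
    prefixes = tuple(drop_prefixes)
    filtered = OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.startswith(prefixes) and not any(s in key for s in drop_substrings)
    )
    # Keep PyTorch's version metadata if the original state dict carried any.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata
    return filtered


# Usage sketch for a SALICON-only model (requires torch and unisal; not executed here):
# state_dict = torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth')
# state_dict = filter_state_dict(
#     state_dict,
#     drop_prefixes=('rnn', 'post_rnn'),
#     drop_substrings=('DHF1K', 'Hollywood', 'UCFSports'),
# )
# model = unisal.model.UNISAL(sources=('SALICON',))
# model.load_state_dict(state_dict)  # strict=True by default: raises on any mismatch
```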
If you want to use the model for static data only, you can reduce the model size by loading it without the GRU RNN (untested):

```python
import torch
import unisal

my_source = 'SALICON'
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)
```
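Because `strict=False` can hide mistakes, another option is to inspect what `load_state_dict` reports: in PyTorch it returns a named tuple with `missing_keys` and `unexpected_keys`. The sketch below (the helper `check_partial_load` is hypothetical) fails if any model weight was left uninitialized, while still allowing the checkpoint to contain extra weights such as those of the dropped GRU or other sources. Treating "no missing keys" as the success criterion for UNISAL is an assumption; use `allowed_missing` for parameters you intend to train from scratch.

```python
def check_partial_load(result, allowed_missing=()):
    """Hypothetical helper: validate the value returned by
    torch.nn.Module.load_state_dict(..., strict=False).

    Raises if a model parameter/buffer was not found in the checkpoint
    (unless its name starts with one of `allowed_missing`) and returns the
    checkpoint keys that were ignored, so they can be reviewed.
    """
    allowed = tuple(allowed_missing)
    missing = [k for k in result.missing_keys if not k.startswith(allowed)]
    if missing:
        raise RuntimeError(f"Weights missing from checkpoint: {missing}")
    return list(result.unexpected_keys)


# Usage sketch (requires torch and unisal; not executed here):
# result = model.load_state_dict(torch.load(path), strict=False)
# ignored = check_partial_load(result)
# print(f"Ignored {len(ignored)} checkpoint entries, e.g. {ignored[:5]}")
```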
In your training code, you can then call the model with:

```python
# ... your code here
# One of ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'); use the source the model was built with
my_source = 'SALICON'
static = True  # True for static (image) data, False for video data
prediction = model(training_batch, source=my_source, static=static)
```
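When setting up the training script, one of the fine-tuning choices mentioned above, freezing parts of the network, can be done generically over `named_parameters()`, a standard PyTorch `nn.Module` method. This is a sketch under assumptions: `freeze_by_prefix` is a hypothetical helper, and the prefix `'cnn.'` in the usage comments is only a guess at how the UNISAL encoder's parameters are named; print `[n for n, _ in model.named_parameters()]` to see the real names.

```python
def freeze_by_prefix(named_params, frozen_prefixes):
    """Hypothetical helper: set requires_grad = False on every parameter whose
    name starts with one of `frozen_prefixes`; other parameters are left untouched.

    `named_params` is an iterable of (name, param) pairs, e.g. the output of
    torch.nn.Module.named_parameters(). Returns the list of frozen names.
    """
    prefixes = tuple(frozen_prefixes)
    frozen = []
    for name, param in named_params:
        if name.startswith(prefixes):
            param.requires_grad = False
            frozen.append(name)
    return frozen


# Usage sketch (requires torch and unisal; not executed here):
# frozen = freeze_by_prefix(model.named_parameters(), ['cnn.'])  # 'cnn.' is a guessed prefix
# optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)
```

Passing only the trainable parameters to the optimizer keeps it from tracking state for frozen weights.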
Thanks for your response, @rdroste! I'll wait for your minimal example. 👍 It makes sense that a new dataset would involve some work on hyperparameter tuning.
For plugging UNISAL into my own training script, it would be great to know which components of model.py are needed when training on just one new dataset (one not in the list of datasets used in your method).
As of now, model.py seems to contain domain-specific normalization, multiple sources, etc. Are these components needed when only one (new) dataset is given for training?
Thanks @rdroste! Let me give this a try.