
DART: Definitely-Agnostic Ranked Trade-off

In an ML pipeline, model selection and model validation are crucial steps, since they largely determine the success of a project. Still, when comparing one model against the others, many practitioners take into account just one statistic at a time; in most cases, the accuracy. For example, when tuning hyperparameters on the validation set with Cross Validation, people tend to select the model with the highest mean accuracy. What about the variance of that model? It is usually forgotten, even though it is something we really care about, since it is a proxy for the stability of the model.

Considering the accuracy alone can be very risky, since it may cause us to overestimate the performance of our model.


To avoid this, we would like a single measure to use for such comparisons during model selection, one that takes into account both the accuracy AND the standard deviation of the model.
Since, to the best of my knowledge, no such measure exists, I have decided to derive a new one that satisfies the requirements stated above. It is named DART, which stands for Definitely-Agnostic Ranked Trade-off:

  • Definitely-Agnostic: we refuse to assume there is a single best model based on the accuracy alone;

  • Ranked: in the end, what this measure does is rank the models, returning the ones that best fit our problem;

  • Trade-off: we no longer consider the accuracy as our only metric, since we now trade it off against the variance of the model.


From now on we will assume K-Fold Cross Validation, since this measure has been designed principally for that setting and since it is the standard when tuning hyperparameters. Because we need a proxy for the variance of the model, as well as an accurate estimate of its accuracy, several trainings of the same model are required.
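
For instance, here is a minimal sketch of how the k per-fold accuracies can be collected with scikit-learn; the dataset and estimator below are placeholders for illustration, not part of this project:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Placeholder data and model; any estimator from the hyperparameters grid works the same way.
X, y = make_classification(n_samples=500, random_state=0)
model = LogisticRegression(max_iter=1000)

# k accuracies for this model, one per fold of K-Fold Cross Validation (k = 5 here).
fold_accuracies = cross_val_score(model, X, y, cv=5, scoring="accuracy")
print(fold_accuracies.mean(), fold_accuracies.std(ddof=1))
```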

Accuracy is NOT the only statistic that matters

Let's start by providing the formula to compute the DART-measure for a given model:

$$ DART_i = \frac{1 + 1.4427\times \ln(\bar{A_i})}{\mathrm{e}^{\mathrm{p}D_i}} \quad\quad i = 1, ..., \lvert G \rvert $$

where $\lvert G \rvert$ is the number of possible configurations given by our hyperparameters grid, i.e. how many models we will test, and $1.4427 \approx 1/\ln 2$, so the numerator is simply $1 + \log_2(\bar{A_i})$.

PARAMETERS

  • $\bar{A_i}$: mean accuracy of model i, computed with respect to the k accuracies given by K-Fold Cross Validation;

$$ \bar{A_i} = \frac{1}{\mathrm{k}}\sum_{j}A_{ij} \quad\quad j = 1,..., \mathrm{k} $$

  • $D_i$: standard deviation of the model, computed with respect to the k accuracies;

$$ D_i = \sqrt{\frac{\sum_{j}(A_{ij}-\bar{A_i})^2}{(\mathrm{k}-1)}} \quad\quad j = 1,..., \mathrm{k} $$

  • p: the desired precision, a fixed non-negative number representing how much importance we give to the stability of the model (a code sketch follows this list).
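
Putting the pieces together, here is a minimal sketch of the formula in Python; the function name and interface are mine, only the formula itself comes from the definitions above:

```python
import numpy as np

def dart(fold_accuracies, p):
    """DART-measure of one model, from its k per-fold accuracies and a precision p >= 0."""
    acc = np.asarray(fold_accuracies, dtype=float)
    mean_acc = acc.mean()                # mean accuracy, A-bar
    std_acc = acc.std(ddof=1)            # sample standard deviation D (k - 1 in the denominator)
    numerator = 1.0 + np.log2(mean_acc)  # 1 + 1.4427 * ln(A-bar) = 1 + log2(A-bar)
    return numerator / np.exp(p * std_acc)
```

For example, `dart(fold_accuracies, p=10)` scores the model cross-validated above; computing it for every configuration of the grid yields the $\lvert G \rvert$ values to compare.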

How is the value of the DART-measure related to its inputs? Ideally, it should be high when the accuracy is high and the variance is small; this would mean we have found a very good model! Conversely, a low value means the model is poor, since it has low accuracy and high variance (or, more precisely, a variance that is too high for our requirements).

Of course, this is just one way of combining the two statistics, and there are plenty of others. Nevertheless, the proposed measure gives enough flexibility to choose which characteristics the desired model should have. The key parameter that allows us to try different configurations is the precision, a non-negative real number that represents the importance we give to the stability of the model.
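
To see the trade-off in action, here is a small illustration reusing the hypothetical `dart` function sketched above; the fold accuracies are made-up numbers, not results:

```python
# Made-up fold accuracies for two hypothetical models.
model_a = [0.95, 0.85, 0.96, 0.84, 0.95]   # higher mean accuracy, higher variance
model_b = [0.90, 0.89, 0.91, 0.90, 0.89]   # slightly lower mean accuracy, very stable

for p in (0, 5, 20):
    print(p, round(dart(model_a, p), 3), round(dart(model_b, p), 3))
# With p = 0 the more accurate model A scores higher; as p grows, the stable model B overtakes it.
```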


Values, parameters and how to interpret them

Let's now look at the values the DART-measure can assume, starting with a table that gives a quick sense of how it behaves. We use '-' when a variable is irrelevant for a given row; two separate tables would have been unwieldy.

| Numerator | Denominator | Accuracy | Standard Deviation | Performance |
|-----------|-------------|----------|--------------------|------------------|
| 1         | -           | 1        | -                  | Very Accurate    |
| 0         | -           | 0.5      | -                  | Random Guessing  |
| < 0       | -           | < 0.5    | -                  | Worst Gambler    |
| -         | 1           | -        | 0                  | Very Stable      |
| -         | > 1         | -        | > 0                | Variance Penalty |
| -         | >>> 1       | -        | >> 0               | Shaking          |


NUMERATOR

  • It assumes its maximum value, 1, when the accuracy is 1, and it equals 0 when the accuracy is 0.5, i.e. random guessing;
  • It assumes negative values whenever the performance is worse than random guessing (REALLY BAD MODEL).

DENOMINATOR

  • It increases exponentially with the standard deviation and assumes its minimum value, 1, when there is no variance;
  • It has no upper bound, so it can grow without limit as the variance increases.

Summarizing, this measure tries to strike a balance between the accuracy and the stability of a model, weighting the variance according to the precision required by the problem. Note also that, as long as the accuracy is at least 0.5, its values lie between 0 and 1.
Finally, we are free to change the value of the precision, which gives us a lot of flexibility. Notice that by setting it to zero we go back to the usual case in which we consider just the accuracy; this means the measure generalizes the accuracy-only criterion, not bad!
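
As a quick worked example with illustrative numbers, take $\bar{A} = 0.9$ and $D = 0.05$; then

$$ 1 + \log_2(0.9) \approx 0.848, \qquad \frac{0.848}{\mathrm{e}^{10 \times 0.05}} \approx 0.514, \qquad \frac{0.848}{\mathrm{e}^{0 \times 0.05}} = 0.848, $$

so with $p = 10$ the variance penalty roughly halves the score, while with $p = 0$ the score depends only on the mean accuracy.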

The general, qualitative behaviour, taking all the parameters into account, is shown in the following table:

| Accuracy | Standard Deviation | Desired stability | DART* |
|----------|--------------------|-------------------|-------|
| High     | High               | High              | ☄️✨ |
| High     | High               | Low               | ☄️☄️☄️✨ |
| High     | Low                | High              | ☄️☄️☄️☄️ |
| High     | Low                | Low               | ☄️☄️☄️☄️☄️ |
| Low      | High               | High              | ☄️ |
| Low      | High               | Low               | ☄️✨ |
| Low      | Low                | High              | ☄️☄️☄️ |
| Low      | Low                | Low               | ☄️☄️✨ |


*We use ☄️ to score the expected goodness of the measure and ✨ to indicate possible variations depending on the specific situation and the exact values. Notice that the Standard Deviation, weighted by the desired stability, is what contributes most to the score: the lower, the better.


This table is very general and is shown here only to give a hint of how the DART-measure behaves. In real situations there may be many exceptions, and absolute values of different orders of magnitude. So why should anyone use such an "unstable" measure? Because what we care about is the relative ordering! For a fixed precision, we can rank the models in different ways, giving more, less, or no importance to the stability. If this still seems obscure, and it probably does, just be a bit more patient and check the notebook; there is a thorough explanation, with plots, of how the DART-measure works.
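
Concretely, and still as a sketch under the same assumptions as before (scikit-learn, the placeholder data `X`, `y`, and the hypothetical `dart` function), ranking a small grid with a fixed precision could look like this:

```python
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

grid = [{"C": 0.1}, {"C": 1.0}, {"C": 10.0}]   # |G| = 3 configurations, chosen for illustration
precision = 10                                  # fixed p: how much we care about stability

scores = []
for params in grid:
    fold_acc = cross_val_score(SVC(**params), X, y, cv=5, scoring="accuracy")
    scores.append((dart(fold_acc, precision), params))

# Higher DART first: the relative ordering is what drives model selection.
for score, params in sorted(scores, key=lambda t: t[0], reverse=True):
    print(round(score, 3), params)
```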

Disclaimer

This is a measure I defined on my own; it is not the result of a scientific paper or of a deep study. I do not know whether it actually makes sense, whether it is useful, or whether it can be further improved. Time and experiments will tell. In the meantime, it is not my responsibility if you use it and end up blowing up your house with a bad model; further tests are still needed.

Easter Eggs

  • The name DART was used by NASA for a space mission aimed at testing a method of planetary defense against asteroids. This is why I have chosen ☄️ to denote the score of the measure. More info on: Double Asteroid Redirection Test. Also, my girlfriend is going to be an astrophysicist and she really liked it;

  • Typing Double Asteroid Redirection Test on Google will do something funny. Feel free to try it;

  • Darts are thrown at a dartboard, and the goal is to get as close as possible to the center. Assuming the player is our model, we want it to be both precise and accurate.


dart-measure's Issues

Study how the measure behaves with different numbers of samples

Let's suppose we are cross-validating a model and we want to compute the measure. However, we don't want to wait until the end to do so, because we would like to use PrunedCV to stop early. How does the DART measure behave when we compare the first folds with the full run? In other words, if we compute the DART measure with only the first two folds, do we get results similar to those obtained with all K folds? Study the behaviour empirically and see whether it is convenient to stop after only two iterations. The main issue is that the standard deviation we put in the formula is not statistically robust when we only have two samples. However, this can also be a good thing: we may use a large number of folds and discard many models if the measure after, say, 5 folds is already bad. We would get a good estimate while discarding all the bad models, but this also means increasing the cost. Study this and find a solution.
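
A minimal starting point for this study, reusing the hypothetical `dart` function and the placeholder `model`, `X`, `y` from the sketches above (PrunedCV itself is not used here):

```python
from sklearn.model_selection import cross_val_score

k = 10
fold_acc = cross_val_score(model, X, y, cv=k, scoring="accuracy")

# DART on the first m folds versus on all k folds, for a fixed precision p = 10.
for m in (2, 5, k):
    print(m, round(dart(fold_acc[:m], 10), 3))
# If the early values already rank models consistently, pruning after a few folds may be viable.
```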

Consider exploiting Covariance

To get a better estimate of the precision we fix, we may consider the sample covariance as a baseline. It may be useful to find a way to link the sample covariance to the expected variance of a model, i.e. if the sample covariance is high, then we can rescale the standard deviation of the model. Why would we need to do this? In the first place, to have an estimate of a fair value for the precision. However, it may also be really useful in PrunedCV, where we don't have enough samples to compute a good estimate of the standard deviation of a model. Using the sample covariance may guide us and allow a sort of proportion with the desired precision so that, even with just 2 or 3 observations, we get a good statistical estimate.
