
evidential-deep-learning's Introduction

Evidential Deep Learning

"All models are wrong, but some — that know when they can be trusted — are useful!"

- George Box (Adapted)

This repository contains the code to reproduce Deep Evidential Regression, as published in NeurIPS 2020, as well as more general code to leverage evidential learning to train neural networks to learn their own measures of uncertainty directly from data!

Setup

To use this package, you must install the following dependencies first:

  • python (>=3.7)
  • tensorflow (>=2.0)
  • pytorch (support coming soon)

Now you can install the package to start adding evidential layers and losses to your models:

pip install evidential-deep-learning

Now you're ready to start using this package directly as part of your existing tf.keras model pipelines (Sequential, Functional, or model-subclassing):

>>> import evidential_deep_learning as edl

Example

To use evidential deep learning, you must edit the last layer of your model to be evidential and use a supported loss function to train the system end-to-end. This repository supports evidential layers for both fully connected and convolutional (2D) layers. The evidential prior distribution presented in the paper follows a Normal Inverse-Gamma and can be added to your model:

import evidential_deep_learning as edl
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        edl.layers.DenseNormalGamma(1), # Evidential distribution!
    ]
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3), 
    loss=edl.losses.EvidentialRegression # Evidential loss!
)

Check out hello_world.py for an end-to-end toy example walking through this step by step. For more complex examples that scale up to computer vision problems (where we learn to predict tens of thousands of evidential distributions simultaneously!), please refer to the NeurIPS 2020 paper and the reproducibility section of this repo to run those examples.
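For quick reference, here is a minimal sketch (not in the README itself) of unpacking the evidential output after training. It assumes DenseNormalGamma(1) concatenates the four NIG parameters (gamma, nu, alpha, beta) along the last axis, as in hello_world.py, and that x_test is a hypothetical held-out input array:

import numpy as np

# Sketch only: `model` is the compiled model from above, after fitting.
y_pred = model.predict(x_test)                         # shape (N, 4) for one target
gamma, nu, alpha, beta = np.split(y_pred, 4, axis=-1)  # the NIG parameters
prediction = gamma                                     # predicted mean of the target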

Reproducibility

All of the results published as part of our NeurIPS paper can be reproduced as part of this repository. Please refer to the reproducibility section for details and instructions to obtain each result.

Citation

If you use this code for evidential learning as part of your project or paper, please cite the following work:

@article{amini2020deep,
  title={Deep evidential regression},
  author={Amini, Alexander and Schwarting, Wilko and Soleimany, Ava and Rus, Daniela},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

evidential-deep-learning's People

Contributors

aamini


evidential-deep-learning's Issues

var in hello_world.py

Thank you very much for the excellent code and the paper. I would like to ask whether the variable var in the function plot_predictions in hello_world.py is actually a standard deviation, and whether it only contains the epistemic uncertainty. If so, how can aleatoric uncertainty be implemented in the same code?
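For context, the paper gives closed-form expressions for both uncertainty components of the NIG output (valid for alpha > 1). A hypothetical sketch, assuming y_pred holds the concatenated parameters from a DenseNormalGamma head:

import numpy as np

def nig_uncertainties(y_pred):
    # Split a DenseNormalGamma prediction into its NIG parameters.
    gamma, nu, alpha, beta = np.split(y_pred, 4, axis=-1)
    aleatoric_var = beta / (alpha - 1)          # E[sigma^2]: noise in the data itself
    epistemic_var = beta / (nu * (alpha - 1))   # Var[mu]: uncertainty of the model
    return aleatoric_var, epistemic_var

Whether var is a variance or a standard deviation can then be checked by comparing it against np.sqrt(epistemic_var) and np.sqrt(aleatoric_var + epistemic_var).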

evidential DL for anomaly detection in time series

Hello, I'm writing my thesis on anomaly detection in time series, and I would like to implement your algorithm to recognize anomalies.
My idea is to calculate the entropy on the training set (ID samples only) and then on the test set (ID + OOD samples), as you did in your paper.
Since some features of my time series may be integers rather than floats, could this be a problem, given that the method assumes the distribution is normal?

In your opinion, can this method be useful here? Do you have any ideas or suggestions?

Thanks in advance!
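A rough sketch of how such a score could look with this repo's regression head, using the epistemic variance as the anomaly score and an ID-only percentile as the threshold (model, x_train, and x_test are hypothetical names here, not from the repo):

import numpy as np

def epistemic_score(y_pred):
    # y_pred: concatenated NIG parameters (gamma, nu, alpha, beta)
    gamma, nu, alpha, beta = np.split(y_pred, 4, axis=-1)
    return (beta / (nu * (alpha - 1))).squeeze(-1)

train_scores = epistemic_score(model.predict(x_train))  # ID samples only
test_scores = epistemic_score(model.predict(x_test))    # ID + OOD samples
threshold = np.percentile(train_scores, 95)             # e.g., flag top 5% as anomalous
is_anomaly = test_scores > threshold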

TypeError: tf__Dirichlet_SOS() missing 1 required positional argument: 't'

First of all, thanks for your valuable contribution. The EDL concept is interesting.

I have tried EDL on a simple classification task as follows:

import evidential_deep_learning as edl
import tensorflow as tf
import sklearn.datasets
import sklearn.model_selection  # needed for train_test_split below

iris = sklearn.datasets.load_iris()
train, test, labels_train, labels_test = sklearn.model_selection.train_test_split(iris.data, iris.target, train_size=0.80)

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        edl.layers.DenseDirichlet(3), # Evidential distribution!
    ]
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3), 
    loss=edl.losses.Dirichlet_SOS # Evidential loss!
)

history = model.fit(train, labels_train, batch_size=1024, epochs=32, verbose=0, validation_split=0.2)

However, I got the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-26b7527e8794> in <module>
     19 )
     20 
---> 21 history = model.fit(train, labels_train, batch_size=1024, epochs=32, verbose=0, validation_split=0.2)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    869       # This is the first call of __call__, so we have to initialize.
    870       initializers = []
--> 871       self._initialize(args, kwds, add_initializers_to=initializers)
    872     finally:
    873       # At this point we know that the initialization is complete (or less

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    724     self._concrete_stateful_fn = (
    725         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 726             *args, **kwds))
    727 
    728     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2967       args, kwargs = None, None
   2968     with self._lock:
-> 2969       graph_function, _ = self._maybe_define_function(args, kwargs)
   2970     return graph_function
   2971 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3359 
   3360           self._function_cache.missed.add(call_context_key)
-> 3361           graph_function = self._create_graph_function(args, kwargs)
   3362           self._function_cache.primary[cache_key] = graph_function
   3363 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3204             arg_names=arg_names,
   3205             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3206             capture_by_value=self._capture_by_value),
   3207         self._function_attributes,
   3208         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    988         _, original_func = tf_decorator.unwrap(python_func)
    989 
--> 990       func_outputs = python_func(*func_args, **func_kwargs)
    991 
    992       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    632             xla_context.Exit()
    633         else:
--> 634           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    635         return out
    636 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

TypeError: in user code:

    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)

    TypeError: tf__Dirichlet_SOS() missing 1 required positional argument: 't'

I was wondering if you could kindly help me to fix this problem.
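Judging from the traceback alone, Dirichlet_SOS appears to take a third positional argument t beyond (y_true, y_pred), which Keras does not supply. A possible workaround (an assumption, not an official fix) is to close over t in a two-argument wrapper:

# Hypothetical workaround: give Keras the two-argument (y_true, y_pred)
# signature it expects, fixing `t` (presumably an annealing step) to a constant.
def dirichlet_sos_loss(y_true, y_pred):
    return edl.losses.Dirichlet_SOS(y_true, y_pred, t=1.0)

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=dirichlet_sos_loss,
)

# Note: the loss may also expect one-hot labels, e.g. tf.one_hot(labels_train, 3).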

Deep Evidential Regression on Toy Dataset

Thank you for the interesting paper and for publishing your code. In order to get a better understanding of your method, I have created a Google Colab (in PyTorch) with a toy regression problem. I have implemented your methodology as well as the one from a follow-up paper by Meinert et al. 2022. However, neither seems to perform very well; in particular, the epistemic uncertainty is not behaving as expected (see below). Therefore, I was wondering if you could give me a hint about what I am missing/doing wrong, or a possible explanation. Thanks in advance!

Typo in S12 -> S13?

[image: the derivation steps S12 -> S13]

When deriving the epistemic uncertainty from S12 to S13, the negative sign on the second term, [(beta/(alpha-1))/v], disappears.
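For reference, the closed-form NIG moments from the paper (valid for alpha > 1), against which the sign in S13 can be checked:

% Prediction and uncertainties of the Normal Inverse-Gamma evidential output:
\mathbb{E}[\mu] = \gamma, \qquad
\underbrace{\mathbb{E}[\sigma^2] = \frac{\beta}{\alpha - 1}}_{\text{aleatoric}}, \qquad
\underbrace{\operatorname{Var}[\mu] = \frac{\beta}{\nu(\alpha - 1)}}_{\text{epistemic}}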

Loss goes to NaN

For a regression task, I am using a mid-sized CNN consisting of Conv and MaxPool layers at the front and Dense layers at the end.

This is how I integrate the evidential loss (before, I used the MSE loss):

optimizer = tf.keras.optimizers.Adam(learning_rate=7e-7)
def EvidentialRegressionLoss(true, pred):
    return edl.losses.EvidentialRegression(true, pred, coeff=CONFIG.EDL_COEFF)
model.compile(
    optimizer=optimizer,
    loss=EvidentialRegressionLoss,
    metrics=["mae"]
)

This is how I integrated the DenseNormalGamma layer:

def build_model():
    model = tf.keras.models.Sequential()
    # ... lots of Conv layers ...
    model.add(layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", activation="relu"))
    model.add(layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", activation="relu"))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(1024, activation="relu"))
    model.add(layers.Dense(128, activation="relu"))

    model.add(edl.layers.DenseNormalGamma(1))  # instead of Dense(1)

    return model

Here is the issue I am facing:

  • Before introducing evidential-deep-learning, I used a learning rate of 0.0007 = 7e-4, which worked well.
  • Now I get loss=NaN with this learning rate; even if I make it smaller (7e-7), I still get loss=NaN, usually already in the very first epoch of training.
  • If I set the learning rate ridiculously low (7e-9), I don't get NaN, but of course the network does not learn fast enough.

Is there an obvious mistake I am making? Any thoughts and help are appreciated.
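Not from this repo, but a common mitigation sketch for exploding evidential losses: keep the original learning rate, clip gradient norms, and start with a small regularization coefficient (the coeff value here is an assumption):

# Hypothetical mitigation, not an official fix: clipnorm bounds each gradient
# update so extreme NIG parameter values cannot immediately produce NaNs.
optimizer = tf.keras.optimizers.Adam(learning_rate=7e-4, clipnorm=1.0)

def EvidentialRegressionLoss(true, pred):
    return edl.losses.EvidentialRegression(true, pred, coeff=1e-2)  # assumed small coeff

model.compile(optimizer=optimizer, loss=EvidentialRegressionLoss, metrics=["mae"])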

Mistake in (S26)?

First of all, thanks a lot for this repo and the papers! I think there is a mistake in (S26) that is also present in the code.

[image: equation (S26)]

Both the Student-t and the Normal Inverse-Gamma distributions have a parameter named ν, and it seems these two parameters are confused in the log(π/ν) summand. I think that summand should be log(π/ν_student) = log(π/(2α)) instead.
I believe this is also wrong in the code.
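For context, the paper derives the model evidence as a Student-t distribution whose degrees of freedom are 2α, which is the ν_student referred to above:

% Marginal likelihood of y under the Normal Inverse-Gamma evidential prior:
p(y \mid \gamma, \nu, \alpha, \beta)
  = \mathrm{St}\!\left(y;\ \gamma,\ \frac{\beta(1+\nu)}{\nu\alpha},\ 2\alpha\right)
% i.e. a Student-t with location gamma, scale^2 = beta(1+nu)/(nu*alpha),
% and degrees of freedom 2*alpha.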

When will the PyTorch version be available?

Hi, this is very interesting and valuable work. I have two questions:
1. When will the PyTorch version be available?
2. I notice that the paper uses regression problems. What about classification? Is the method suitable for classification problems?
Thanks.

NIG_Loss smaller than zero!

The NIG loss becomes negative during training, and I don't know why. Thank you!

Epoch: 0, iter: 360, abs_loss: 0.075639, nll_loss: -0.842837, reg_loss:0.364714 ,acc: 0.952216, lr_pre: 0.000030, lr_last: 0.000300
Epoch: 0, iter: 361, abs_loss: 0.075609, nll_loss: -0.843703, reg_loss:0.364581 ,acc: 0.952219, lr_pre: 0.000030, lr_last: 0.000300
Epoch: 0, iter: 362, abs_loss: 0.075563, nll_loss: -0.844643, reg_loss:0.364431 ,acc: 0.952307, lr_pre: 0.000030, lr_last: 0.000300
Epoch: 0, iter: 363, abs_loss: 0.075479, nll_loss: -0.846680, reg_loss:0.364073 ,acc: 0.952395, lr_pre: 0.000030, lr_last: 0.000300
Epoch: 0, iter: 364, abs_loss: 0.075361, nll_loss: -0.849301, reg_loss:0.363639 ,acc: 0.952526, lr_pre: 0.000030, lr_last: 0.000300
Epoch: 0, iter: 365, abs_loss: 0.075248, nll_loss: -0.851855, reg_loss:0.363197 ,acc: 0.952655, lr_pre: 0.000030, lr_last: 0.000300
Epoch: 0, iter: 366, abs_loss: 0.075118, nll_loss: -0.854777, reg_loss:0.362667 ,acc: 0.952784, lr_pre: 0.000030, lr_last: 0.000300

Minimum size of training set

There is empirical evidence that Chemprop can learn meaningful representations from a dataset of at least 1K SMILES/property pairs; this has been the case for most of the experiments I have carried out. However, when applying evidential deep learning, this no longer seems to hold. My understanding is that this might be because the output layer now predicts the parameters of a normal inverse-gamma distribution, and modeling that might require more data (I am fine with that). Is this assumption correct?

How did I get to this point? I took a dataset of 1.2K data points and randomly partitioned it 80%/20% into training and test sets, respectively. If I use Chemprop for a regression task without evidential learning, the metrics for evaluating predictive power (MAE, RMSE, and R2) are decent. But if I train the evidential variant on the same dataset, the model cannot predict the test set. Of course, it also lets me know that it is very uncertain about its predictions, but I was surprised to see such a degradation in generalization.

Any thoughts would be appreciated.

Best,
