
torch_truncnorm's Introduction

torch_truncnorm

Truncated Normal distribution in PyTorch. The module provides:

  • TruncatedStandardNormal class - the parent standard Normal distribution (zero mean, unit variance) truncated to the cut-off range [a, b] (similar to scipy.stats.truncnorm);
  • TruncatedNormal class - a wrapper that adds loc and scale parameters for the parent Normal distribution;
  • Differentiability with respect to the parameters of the distribution;
  • Batching support.
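
A minimal usage sketch (the import path is assumed from the repository layout, and the parameter values are illustrative):

import torch

from TruncatedNormal import TruncatedNormal

loc = torch.tensor([0.0, 0.3], requires_grad=True)
scale = torch.tensor([1.0, 0.5], requires_grad=True)

# Parent Normal(loc, scale) truncated to [-1, 1]; parameters batch elementwise.
d = TruncatedNormal(loc, scale, a=-1.0, b=1.0)

x = d.rsample((5,))             # reparameterized samples
d.log_prob(x).sum().backward()  # gradients flow back to loc and scale
print(loc.grad, scale.grad)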

Why

I just needed differentiation with respect to the parameters of the distribution and found that the truncated normal distribution is not bundled in torch.distributions as of PyTorch 1.6.0.

Known issues

icdf is numerically unstable; as a consequence, so is rsample. The same issue affects torch.distributions.normal.Normal, so it is sort of normal (ba-dum-tss).
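
A common mitigation (a sketch, not part of this module; stable_rsample and eps are illustrative names, and _extended_shape is the private helper inherited from torch.distributions.Distribution) is to keep the uniform draw away from 0 and 1 so that icdf never sees its unstable tails:

import torch

def stable_rsample(dist, sample_shape=torch.Size(), eps=1e-6):
    # Clamp the uniform draw to [eps, 1 - eps]; this trades a tiny bias
    # for keeping icdf away from its numerically unstable tails.
    # (Device handling is omitted for brevity.)
    shape = dist._extended_shape(sample_shape)
    u = torch.rand(shape).clamp(eps, 1.0 - eps)
    return dist.icdf(u)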

Tests

CUDA_VISIBLE_DEVICES=0 python -m tests.test

Links

https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

torch_truncnorm's People

Contributors

jjmorton, pierresegonne, toshas


torch_truncnorm's Issues

Plot of pdf seems to suggest that the truncation is faulty

Hey! Thanks for sharing this!

To briefly test the truncated normal distribution, I wanted to verify the pdf, so I wrote the following snippet:

import matplotlib.pyplot as plt
import torch

from sggm.vae_model_helper import TruncatedNormal

a = 0
b = 1
mu = 0.8
std = 0.5

p = TruncatedNormal(loc=mu, scale=std, a=a, b=b, validate_args=True)

x = torch.linspace(-0.5, 1.5, 100)

fig, ax = plt.subplots()
ax.plot(x, p.log_prob(x).exp().flatten())

plt.show()

This results in the plot below. Did I completely miss the point, or is the truncation not happening?

Thanks :)
(Figure 1: the resulting pdf plot)
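
For what it's worth, a direct numeric check of the truncation (a sketch reusing the imports from the snippet above; validate_args=False allows evaluating log_prob outside the support) is to compare densities just inside and just outside [a, b]:

q = TruncatedNormal(loc=0.8, scale=0.5, a=0.0, b=1.0, validate_args=False)

inside = torch.tensor([0.25, 0.75])
outside = torch.tensor([-0.25, 1.25])

print(q.log_prob(inside).exp())   # strictly positive inside [a, b]
print(q.log_prob(outside).exp())  # should be 0 if the truncation is applied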

Invalid value when calling log_prob after sample

My code looks like:

    m = TruncatedNormal(loc, scale, 0, 1)
    action_pt = m.sample()
    return m.log_prob(action_pt)

It looks like action_pt can take the value 1.0, which causes log_prob to raise an error:

    111     def log_prob(self, value):
    112         if self._validate_args:
--> 113             self._validate_sample(value)
    114         return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
    115 

~/.pyenv/versions/3.8.8/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    291         valid = support.check(value)
    292         if not valid.all():
--> 293             raise ValueError(

I don't know whether the problem is:

  1. that the value 1.0 should never be sampled in the first place, or
  2. that the value 1.0 is in the valid interval and shouldn't be rejected as impossible.
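
A possible workaround until this is settled (a sketch; sample_and_score and eps are illustrative, not part of the library) is to clamp the sample strictly inside the interval before scoring it:

import torch

def sample_and_score(m, low=0.0, high=1.0, eps=1e-6):
    # Keep samples away from the boundary values that the
    # support check in log_prob rejects.
    action_pt = m.sample().clamp(low + eps, high - eps)
    return action_pt, m.log_prob(action_pt)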

Upload to pypi

Thanks for making this. Very helpful! It'd be really nice if it could be installed with pip.

device issue in `rsample()`

Hello~ Thanks for this implementation!

In the TruncatedStandardNormal.rsample() method, a new tensor is created that is placed on the CPU by default, as shown in this line. When loc and scale live on a CUDA device, an error then occurs in icdf().

Since the torch.distributions.Distribution class does not have a to(device) method, I think line 97 can be changed to:

p = torch.empty(shape).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1).to(self._big_phi_a.device)
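
Alternatively, creating the tensor on the right device up front would avoid the extra copy; a variant of the same line (a sketch, with the same attribute names as above):

p = torch.empty(shape, device=self._big_phi_a.device).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1)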

Incorrect mean predictions for distributions with a loc much smaller than a

I have observed that an incorrect mean value is returned (by calling the mean method) when the distribution has a loc value much smaller than its lower bound a. More specifically, it predicts a mean lower than a, even though the mean of a truncated Gaussian always lies within the bounds [a, b].

Example:

>>> from truncated_gaussian import TruncatedNormal
>>> TruncatedNormal(0,1,-1,1).mean
tensor(0.)
>>> TruncatedNormal(-0.5,1,-1,1).mean
tensor(-0.1437)
>>> TruncatedNormal(-1,1,-1,1).mean
tensor(-0.2772)
>>> TruncatedNormal(-2,1,-1,1).mean
tensor(-0.4900)
>>> TruncatedNormal(-3,1,-1,1).mean
tensor(-0.6294)
>>> TruncatedNormal(-4,1,-1,1).mean
tensor(-0.7173)
>>> TruncatedNormal(-5,1,-1,1).mean
tensor(-0.7797)
>>> TruncatedNormal(-6,1,-1,1).mean
tensor(-1.0114)
>>> TruncatedNormal(-7,1,-1,1).mean
tensor(-6.9490)
>>> TruncatedNormal(-10,1,-1,1).mean
tensor(-10.)
>>> TruncatedNormal(-100000,1,-1,1).mean
tensor(-100000.)

As you can see, when mu (the loc parameter of the distribution) is -6 or below, the reported mean drops below the bound a = -1, even though this should never happen.
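
For reference, these means can be cross-checked against scipy.stats.truncnorm (a sketch; note that scipy expects the truncation bounds standardized by loc and scale, and reference_mean is a hypothetical helper, not part of this repository):

from scipy.stats import truncnorm

def reference_mean(loc, scale, a, b):
    # scipy parameterizes the truncation interval in standard-normal units
    a_std = (a - loc) / scale
    b_std = (b - loc) / scale
    return truncnorm(a_std, b_std, loc=loc, scale=scale).mean()

print(reference_mean(-6.0, 1.0, -1.0, 1.0))       # should lie inside [-1, 1]
print(reference_mean(-100000.0, 1.0, -1.0, 1.0))  # likewise, close to -1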

NaN variance with a or b as infinity

Hi,

It seems that when a or b is defined as math.inf, the variance becomes NaN.
Unless I am missing something, the variance should still be well defined in these cases.

For example, comparing to the equivalent scipy function scipy.stats.truncnorm:

  • scipy.stats.truncnorm(a=-0.5, b=math.inf, loc=7500, scale=15000) gives variance 1.1E+08
  • TruncatedNormal(a=-0.5, b=math.inf, loc=7500, scale=15000) gives variance NaN

I believe it comes down to this line:

self._lpbb_m_lpaa_d_Z = (self._little_phi_b * self.b - self._little_phi_a * self.a) / self._Z

Little phi is correctly calculated as zero at the infinite limit, but 0 * math.inf results in NaN.

As a quick fix, I replaced infinities with zeros in the line above, for example:

self._lpbb_m_lpaa_d_Z = (self._little_phi_b * (0 if math.isinf(self.b) else self.b) - self._little_phi_a * (0 if math.isinf(self.a) else self.a)) / self._Z
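
If a and b can be tensors rather than Python floats, a broadcast-friendly variant of the same idea (a sketch, not code from this repository) zeroes out the infinite bounds elementwise:

safe_a = torch.where(torch.isinf(self.a), torch.zeros_like(self.a), self.a)
safe_b = torch.where(torch.isinf(self.b), torch.zeros_like(self.b), self.b)
self._lpbb_m_lpaa_d_Z = (self._little_phi_b * safe_b - self._little_phi_a * safe_a) / self._Z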
