Comments (2)

torfjelde commented on August 15, 2024

The default implementation of with_logabsdet_jacobian for a Bijector is (transform(b, x), logabsdetjac(b, x)), but since you haven't defined logabsdetjac(::Inverse{Tanh}, y), you also hit the default impl of this, which is -logabsdetjac(inverse(b), inverse(b)(y)).

You then get a stack overflow error because transform(::Inverse{Tanh}, y) is also not defined (Scale does not have an Inverse{<:Scale} implementation because its inverse is obtained by simply inverting the scale factor and returning a new Scale).

In fact, you don't really need to mess around with the Bijector stuff here at all: tanh is already a function, so you don't need a "new" representation of it, and its inverse atanh is likewise already defined.

I'd implement the above as:

using ChangesOfVariables, InverseFunctions, StatsFuns

InverseFunctions.inverse(::typeof(tanh)) = atanh
InverseFunctions.inverse(::typeof(atanh)) = tanh

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(tanh), x::Real)
    y = tanh(x)
    return y, _logabsdetjac_tanh(x)
end
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(atanh), y::Real)
    x = atanh(y)
    return x, -_logabsdetjac_tanh(x)
end

# Use the irrational representation `StatsFuns.logtwo` to defer type-promotion.
# Similarly, I've removed all explicit usages of `Float64`, e.g. converted `2.0` to `2`
# to allow type-promotion to do its thing rather than forcing usage of `Float64`.
_logabsdetjac_tanh(x::Real) = 2 * (StatsFuns.logtwo - x - softplus(-2 * x))
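
For reference, the closed form used in `_logabsdetjac_tanh` can be derived from the derivative of tanh. This is a short worked derivation, assuming `softplus(z) = log(1 + exp(z))`, which is the usual definition:

```latex
\begin{aligned}
\frac{d}{dx}\tanh(x) &= 1 - \tanh^2(x) = \frac{4}{\left(e^{x} + e^{-x}\right)^{2}} \\
\log\left|\frac{d}{dx}\tanh(x)\right|
  &= 2\log 2 - 2\log\left(e^{x} + e^{-x}\right) \\
  &= 2\log 2 - 2\left[x + \log\left(1 + e^{-2x}\right)\right] \\
  &= 2\left(\log 2 - x - \operatorname{softplus}(-2x)\right)
\end{aligned}
```

Writing it this way avoids evaluating `log(1 - tanh(x)^2)` directly, which loses precision once `tanh(x)` rounds to ±1 for large `|x|`.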

If you want a version that is supposed to act elementwise, then you can use Bijectors.elementwise(f):

julia> using Bijectors

julia> elementwise(tanh)(rand(10))
10-element Vector{Float64}:
 0.22076308094447367
 0.06828859488600718
 0.3496810171644955
 0.02413051400382789
 0.6228303792319176
 0.5772825278828461
 0.7370222452215927
 0.45865543543291265
 0.6128386429868988
 0.7094298145373448

julia> with_logabsdet_jacobian(elementwise(tanh), rand(10))
([0.5475308984676883, 0.7498770212815672, 0.11406375475912378, 0.04598020777639154, 0.41278517115619784, 0.3067650082385844, 0.6441810700388316, 0.7430095366528289, 0.7023124306195118, 0.2806093226497268], -3.5844094465162772)
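
The second element of the tuple above is a single number rather than a vector. For an elementwise map the Jacobian is diagonal, so the log-abs-determinant is the sum of the per-element terms. Here is a minimal Base-only sketch of that idea. The names `softplus_sketch`, `logabsdetjac_tanh_sketch` and `tanh_elementwise_sketch` are hypothetical helpers made up for illustration; they are not part of Bijectors.jl and this is not a claim about how `elementwise` is implemented internally:

```julia
# Hypothetical Base-only helpers for illustration; not part of Bijectors.jl.
# Naive softplus, adequate for moderate inputs (assumption: no overflow handling).
softplus_sketch(x::Real) = log1p(exp(x))

# Scalar log|d tanh(x)/dx|, same closed form as in the comment above.
logabsdetjac_tanh_sketch(x::Real) = 2 * (log(2) - x - softplus_sketch(-2 * x))

function tanh_elementwise_sketch(xs::AbstractVector{<:Real})
    ys = tanh.(xs)
    # Diagonal Jacobian: log|det J| is the sum of the per-element log-derivatives.
    return ys, sum(logabsdetjac_tanh_sketch, xs)
end

xs = [0.1, 0.5, 0.9]
ys, ladj = tanh_elementwise_sketch(xs)

# Cross-check against the direct derivative 1 - tanh(x)^2.
@assert ladj ≈ sum(x -> log(1 - tanh(x)^2), xs)
```

The final `@assert` compares the closed form against the direct expression for the derivative on a few moderate inputs, where both are numerically well behaved.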

from bijectors.jl.

hanyas commented on August 15, 2024

I've also tried to define the bijector by following a recipe similar to that of the Scale bijector, but without success:

struct Tanh <: Bijector end

with_logabsdet_jacobian(b::Tanh, x) = transform(b, x), logabsdetjac(b, x)

transform(b::Tanh, x) = tanh(x)
transform(b::Tanh, x::AbstractVecOrMat) = tanh.(x)
transform(ib::Inverse{<:Tanh}, y) = transform(atanh, y)
transform(ib::Inverse{<:Tanh}, y::AbstractVecOrMat) = transform(@. atanh, y)

logabsdetjac(b::Tanh, x::Real) = _logabsdetjac_tanh(b, x, Val(0))
function logabsdetjac(b::Tanh, x::AbstractArray{<:Real,N}) where {N}
    return _logabsdetjac_tanh(b, x, Val(N))
end

_logabsdetjac_tanh(b::Tanh, x::Real, ::Val{0}) = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
_logabsdetjac_tanh(b::Tanh, x::AbstractVector, ::Val{1}) = 2.0 * (log(2.0) - x - softplus(-2.0 * x)) * length(x)
