# NOTE(review): the next four lines reference the loop index `k` and the arrays
# `x`/`z` *before* they are defined below — this looks like a stray duplicate of
# the measurement block at the bottom of the k-loop. Confirm against the full
# file; the enclosing `@model function` header is not visible in this chunk.
x[1,k] ~ H[1]*z[1,k]
for d = 2:D
x[d,k] ~ x[d-1,k] + H[d]*z[d,k]
end
# Preallocate
# z: D layers of latent states over T time steps; x: deterministic mixtures of z;
# y: observed data sequence of length T.
z = randomvar(D,T)
x = randomvar(D,T)
y = datavar(Float64, T)
### k = 1
# Initial state
# Each layer starts from a standard-normal prior.
for d = 1:D
z[d,1] ~ NormalMeanVariance(0.0, 1.0)
end
# Measurement function
# x accumulates the weighted layer contributions: x[d] = sum_{i<=d} H[i]*z[i].
x[1,1] ~ H[1]*z[1,1]
for d = 2:D
x[d,1] ~ x[d-1,1] + H[d]*z[d,1]
end
# Likelihood
# Only the deepest accumulated value x[D] is observed, with unit noise variance.
y[1] ~ NormalMeanVariance(x[D,1], 1.0)
### k >= 1
for k = 2:T
# State transition
# Gaussian random walk per layer with unit transition variance.
for d = 1:D
z[d,k] ~ NormalMeanVariance(z[d,k-1], 1.0)
end
# Measurement function
x[1,k] ~ H[1]*z[1,k]
for d = 2:D
x[d,k] ~ x[d-1,k] + H[d]*z[d,k]
end
# Likelihood
y[k] ~ NormalMeanVariance(x[D,k], 1.0)
end
return y, x, z
end
ReactiveMP.make_actor(x::Matrix{ <: RandomVariable }, ::ReactiveMP.KeepLast) = buffer(Marginal, size(x))
ReactiveMP.make_actor(::Matrix{ <: RandomVariable }, ::ReactiveMP.KeepEach) = keep(Matrix{Marginal})
# Test data
T = 100         # number of time steps
D = 7           # number of hierarchical layers (model depth)
# Mixing coefficients, one per layer. Was hard-coded `randn(7)`, which would
# silently desynchronize from the model if D were ever changed — use D instead.
H = randn(D)
y = randn(T)    # synthetic observations (pure noise; fine for an MWE)
# Define factorization
# Mean-field split between the two chains: the joint posterior over (x, z) is
# constrained to factorize as q(x)q(z). DSL is interpreted by the
# `@constraints` macro (RxInfer/ReactiveMP) — not ordinary Julia assignment.
constraints = @constraints begin
q(x,z) = q(x)q(z)
end
# Automatic inference API
# Run variational message passing on the `mwe` model defined above.
results = inference(
model = Model(mwe, H, D=D, T=T),
data = (y = y,),
constraints = constraints,
# limit_stack_depth avoids stack overflow on long reactive chains (T*D nodes).
options = (limit_stack_depth = 100,),
# Initial messages/marginals for both latent chains; a single distribution is
# presumably broadcast over every element of z and x — TODO confirm against
# the RxInfer `inference` docs for matrix-shaped variables.
initmessages = (
z = NormalMeanVariance(0.0, 1.0),
x = NormalMeanVariance(0.0, 1.0)
),
initmarginals = (
z = NormalMeanVariance(0.0, 1.0),
x = NormalMeanVariance(0.0, 1.0)
),
# Keep only the marginals from the final iteration (see make_actor overloads).
returnvars = (z = KeepLast(), x = KeepLast()),
free_energy = true,
showprogress = true,
iterations = 100
)
# Extract approximate posteriors
qz = results.posteriors[:z]
qx = results.posteriors[:x]
# Plot difference between H*z and x
# In the model, x[D,k] = sum_d H[d]*z[d,k], so the two curves should coincide
# if inference is consistent.
hz_means = [H' * mean.(qz[:, k]) for k in 1:T]  # posterior mean of H'z per step
x_means  = [mean(qx[D, k]) for k in 1:T]        # posterior mean of deepest x per step
plot(1:T, hz_means, color="purple", label="H*qz")
plot!(1:T, x_means, color="green", label="qx")