
scarlet2's People

Contributors

charlotteaward, jaredcsiegel, pmelchior, sampsonml

Forkers

sampsonml, b-remy

scarlet2's Issues

Adjustable prior weighting?

Do we want to give users the ability to adjust the weight on the gradient returned by the prior in the optimisation step? Empirically, I find this helps in very crowded scenes and high-SNR observations. Of course, this means the deblended galaxy will be more prior-dominated, but we now have a metric to check this. Here is an example. First, a tricky scene with no change to the prior gradient weighting:

[Images: prior_default, prior_default2]

Here is the same scene with a stronger prior:
[Images: prior_weight_5, prior_5_2]

The residuals are not as good with the stronger prior, but the noise reduction is better. Changes to the optimisation routine may improve the residuals for the stronger prior if we go this route.
The choice is a trade-off: without extra weighting, highly noisy or blended scenes may need longer runs, because the de-noising action of the prior takes longer to take effect, but we stay more "true" to the data. Just a suggestion; no strong feelings yet about which is best.
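A minimal sketch of what adjustable weighting could look like, assuming the weight simply multiplies the log-prior term in the loss. `prior_weight` is a hypothetical argument, and the Gaussian likelihood and prior below are toy stand-ins for `obs.log_likelihood` and the score prior:

```python
import jax
import jax.numpy as jnp


def weighted_loss(x, data, prior_weight=1.0):
    """Negative log-posterior with a hypothetical weight on the prior term."""
    # toy Gaussian log-likelihood with unit noise (stand-in for obs.log_likelihood)
    log_like = -0.5 * jnp.sum((data - x) ** 2)
    # toy standard-normal log-prior (stand-in for a score-based prior)
    log_prior = -0.5 * jnp.sum(x ** 2)
    return -(log_like + prior_weight * log_prior)


data = jnp.array([2.0])
x = jnp.array([1.0])
# d(loss)/dx = (x - data) + prior_weight * x
g_default = jax.grad(weighted_loss)(x, data, 1.0)  # (1 - 2) + 1 * 1 = 0
g_strong = jax.grad(weighted_loss)(x, data, 5.0)   # (1 - 2) + 5 * 1 = 4
```

Because the score prior only contributes a gradient (its forward pass returns 0.0), multiplying the log-prior term by a constant multiplies that gradient by the same constant. In this toy problem the optimum moves from x = data / 2 to x = data / 6 as the weight goes from 1 to 5.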

correct behavior when constraint and prior are set

It's easy to get into trouble when a parameter specifies both a constraint and a prior. The problem stems from the transformation applied to the variable to make it appear unconstrained (to the sampler or the optimizer). E.g., a positive constraint amounts to inserting $x = f(y) = \exp(y)$ into the path, which means that the original variable $x$ is reparameterized as $y = \log(x)$. For the likelihood, this transformation is transparent because autograd simply applies one more chain rule, but for a prior (e.g. a score model), we compute the gradient $\nabla_x \log p(x)$, not $\nabla_y \log p(y)$. People will run into this without even knowing it.

So we need a warning in those cases: the same transformation $f$ needs to be applied when training and testing the prior network.
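The mismatch shows up in a toy one-dimensional example (a sketch, not scarlet2 code): take an Exponential(1) prior on the positive variable $x$ and let the optimizer work on $y = \log x$. The density of $y$ picks up the Jacobian term $\log|dx/dy| = y$, so the gradient the sampler needs differs both from the score in $x$ and from that score pushed through the chain rule alone:

```python
import jax
import jax.numpy as jnp


def log_p_x(x):
    # toy prior on the constrained variable: Exponential(1), log p(x) = -x for x > 0
    return -x


def log_p_y(y):
    # density of the unconstrained variable y = log(x):
    # log p_Y(y) = log p_X(exp(y)) + log|dx/dy| = log p_X(exp(y)) + y
    return log_p_x(jnp.exp(y)) + y


y = 0.0  # corresponds to x = 1
score_x = jax.grad(log_p_x)(jnp.exp(y))                  # d/dx log p_X at x = 1: -1
chain_only = jax.grad(lambda t: log_p_x(jnp.exp(t)))(y)  # chain rule, no Jacobian: -exp(y) = -1
score_y = jax.grad(log_p_y)(y)                           # -exp(y) + 1 = 0
```

A prior network trained on samples of $x$ returns `score_x`; an optimizer or sampler working in $y$ needs `score_y`.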

Occasional padding error during gradient calculation for prior

For particular sets of images, I get the following padding problem when running scene.fit for sources with the NN prior. I get it with both the HSC and the ZTF priors. Could this be an issue when sources are too close to the edge of the image?

ValueError                                Traceback (most recent call last)
Cell In[31], line 3
      1 #Free up the position parameter of the variable point source
      2 scene.set_info('sources.'+str(indtransient)+'.morphology.center', fixed=False)
----> 3 scene_ = scene.fit(observations_sc2, max_iter=1000, e_rel=1e-6)

File ~/newscarlet/scarlet2/scarlet2/scene.py:176, in Scene.fit(self, observations, max_iter, e_rel, progress_bar, callback, **kwargs)
    173 with tqdm.trange(max_iter, disable=not progress_bar) as t:
    174     for step in t:
    175         # optimizer step
--> 176         scene_, loss, opt_state = _make_step(scene, observations, optim, opt_state, filter_spec=filter_spec,
    177                                              constraint_fn=constraint_fn)
    178         # Log the loss in the tqdm progress bar
    179         t.set_postfix(loss=f"{loss:08.2f}")

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/equinox/_jit.py:103, in _JitWrapper.__call__(self, *args, **kwargs)
    102 def __call__(self, /, *args, **kwargs):
--> 103     return self._call(False, args, kwargs)

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/equinox/_jit.py:99, in _JitWrapper._call(self, is_lower, args, kwargs)
     97         out = self._cached(dynamic, static)
     98 else:
---> 99     out = self._cached(dynamic, static)
    100 return _postprocess(out)

    [... skipping hidden 12 frame]

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/equinox/_jit.py:37, in _filter_jit_cache.<locals>.fun_wrapped(dynamic, static)
     35 fun = hashable_combine(dynamic_fun, static_fun)
     36 args, kwargs = hashable_combine(dynamic_spec, static_spec)
---> 37 out = fun(*args, **kwargs)
     38 dynamic_out, static_out = partition(out, is_array)
     39 return dynamic_out, Static(static_out)

File ~/newscarlet/scarlet2/scarlet2/scene.py:243, in _make_step(model, observations, optim, opt_state, filter_spec, constraint_fn)
    240         return loss_fn(model)
    242     diff_model, static_model = eqx.partition(model, filter_spec)
--> 243     loss, grads = filtered_loss_fn(diff_model, static_model)
    245 updates, opt_state = optim.update(grads, opt_state)
    246 model_ = eqx.apply_updates(model, updates)

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/equinox/_ad.py:59, in _ValueAndGradWrapper.__call__(self, x, *args, **kwargs)
     56     return self._fun(_x, *_args, **_kwargs)
     58 diff_x, nondiff_x = partition(x, is_inexact_array)
---> 59 return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)

    [... skipping hidden 8 frame]

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/equinox/_ad.py:56, in _ValueAndGradWrapper.__call__.<locals>.fun_value_and_grad(_diff_x, _nondiff_x, *_args, **_kwargs)
     53 @ft.partial(jax.value_and_grad, has_aux=self._has_aux, **self._gradkwargs)
     54 def fun_value_and_grad(_diff_x, _nondiff_x, *_args, **_kwargs):
     55     _x = combine(_diff_x, _nondiff_x)
---> 56     return self._fun(_x, *_args, **_kwargs)

File ~/newscarlet/scarlet2/scarlet2/scene.py:240, in _make_step.<locals>.filtered_loss_fn(diff_model, static_model)
    237 @eqx.filter_value_and_grad
    238 def filtered_loss_fn(diff_model, static_model):
    239     model = eqx.combine(diff_model, static_model)
--> 240     return loss_fn(model)

File ~/newscarlet/scarlet2/scarlet2/scene.py:228, in _make_step.<locals>.loss_fn(model)
    226 parameters = model.get_parameters(return_info=True) 
    227 log_like = sum(obs.log_likelihood(pred) for obs in observations)
--> 228 log_prior = sum(info["prior"].log_prob(p)
    229                 for name, (p, info) in parameters.items()
    230                 if info["prior"] is not None
    231                 )
    232 return -(log_like + log_prior)

File ~/newscarlet/scarlet2/scarlet2/scene.py:228, in <genexpr>(.0)
    226 parameters = model.get_parameters(return_info=True) 
    227 log_like = sum(obs.log_likelihood(pred) for obs in observations)
--> 228 log_prior = sum(info["prior"].log_prob(p)
    229                 for name, (p, info) in parameters.items()
    230                 if info["prior"] is not None
    231                 )
    232 return -(log_like + log_prior)

    [... skipping hidden 5 frame]

File ~/newscarlet/scarlet2/scarlet2/nn.py:155, in NNPrior.log_prob_fwd(x)
    153 def log_prob_fwd(x):
    154     # Returns primal output and residuals to be used in backward pass by f_bwd
--> 155     nn_grad = calc_grad(x, custom_model, log_flag)
    156     return 0.0, nn_grad

File ~/newscarlet/scarlet2/scarlet2/nn.py:91, in calc_grad(x, trained_model, log_space)
     89 t = 0.0 # corresponds to noise free gradient
     90 x = jnp.float32(x) # cast to float32
---> 91 x, ScoreNet, pad_lo, pad_hi, pad = pad_fwd(x, trained_model, log_space)
     92 assert (x.shape[1] % 32) == 0, f"image size must be 32 or 64, got: {x.shape[1]}"
     93 # Scorenet needs (n, 64, 64) or (n, 32, 32)

File ~/newscarlet/scarlet2/scarlet2/nn.py:73, in pad_fwd(x, trained_model, log_space)
     71             x = jnp.pad(x, ((pad_lo, pad_hi), (pad_lo, pad_hi)), 'constant', constant_values=-6)       # minimum value of log-space
     72         else:
---> 73             x = jnp.pad(x, ((pad_lo, pad_hi), (pad_lo, pad_hi)), 'constant', constant_values=0)
     74 return x, ScoreNet, pad_lo, pad_hi , pad

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1757, in pad(array, pad_width, mode, **kwargs)
   1754 end_values = kwargs.get('end_values', 0)
   1755 reflect_type = kwargs.get('reflect_type', "even")
-> 1757 return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)

    [... skipping hidden 12 frame]

File ~/.conda/envs/scarlet/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1685, in _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)
   1682   raise ValueError(f"Expected pad_width to have shape {(nd, 2)}; got {pad_width_arr.shape}.")
   1684 if np.any(pad_width_arr < 0):
-> 1685   raise ValueError("index can't contain negative values")
   1687 if mode == "constant":
   1688   return _pad_constant(array, pad_width, asarray(constant_values))

ValueError: index can't contain negative values
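One plausible cause, judging from the traceback (not confirmed): `pad_fwd` pads the image up to the network's input size (32 or 64, per the assertion in `calc_grad`), so a box larger than that input size gives negative `pad_lo`/`pad_hi`, and `jnp.pad` rejects negative widths. A guard along these lines would turn this into an explicit error. `pad_to_size` is a hypothetical helper, and its padding split is an assumption rather than scarlet2's exact arithmetic:

```python
import jax.numpy as jnp


def pad_to_size(x, target, constant=0.0):
    """Center-pad a square 2D image to (target, target), failing loudly if it is too big."""
    size = x.shape[-1]  # shapes are static, so this Python check also works under jit
    diff = target - size
    if diff < 0:
        raise ValueError(
            f"box size {size} exceeds the prior's input size {target}; "
            "shrink the box or use a larger prior model"
        )
    pad_lo = diff // 2
    pad_hi = diff - pad_lo  # odd differences put the extra pixel at the high end
    padded = jnp.pad(x, ((pad_lo, pad_hi), (pad_lo, pad_hi)),
                     mode="constant", constant_values=constant)
    return padded, pad_lo, pad_hi


padded, lo, hi = pad_to_size(jnp.ones((29, 29)), 32)
```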

Problem with relative step sizes

Using either
from scarlet2 import relative_step
Parameter(jnp.ones(5), constraint=constraints.positive, stepsize=relative_step)
or
Parameter(jnp.ones(5), constraint=constraints.positive, stepsize=lambda x: jnp.linalg.norm(x)*1e-1)
yields NaNs in the loss function, and the final models are entirely NaN.

Bug in calculating convergence criteria when there is a point source model in the scene

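When computing the gradient convergence criterion, scarlet2 accesses morphology.data for every source, but the morphology of a point source model (a GaussianMorphology in the traceback below) has no data attribute.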

AttributeError                            Traceback (most recent call last)
Cell In[22], line 3
----> 3 scene_ = scene.fit(observations, max_iter=300, e_rel=1e-6)

File ~/scarletpriorfix/scarlet2/scarlet2/scene.py:201, in Scene.fit(self, observations, max_iter, e_rel, progress_bar, callback, **kwargs)
    197 crit_spec = lambda x, x_: jnp.linalg.norm(x - x_) < e_rel * jnp.linalg.norm(x_)
    198 # gradients converged
    199 converged_grads = tuple(
    200     crit_grad(grads_, grads) for (grads_, grads) in zip(
--> 201         [ grads_.morphology.data for grads_ in grads_.sources ], 
    202         [ grads.morphology.data for grads in grads.sources ]))
    203 # spectrum converged
    204 converged_spec = tuple(
    205     crit_spec(spec, spec_) for (spec, spec_) in zip(
    206         [ spec.parameters['spectrum.data']for spec in scene.sources ], 
    207         [ spec.parameters['spectrum.data']for spec in scene_.sources ]))

File ~/scarletpriorfix/scarlet2/scarlet2/scene.py:201, in <listcomp>(.0)
    197 crit_spec = lambda x, x_: jnp.linalg.norm(x - x_) < e_rel * jnp.linalg.norm(x_)
    198 # gradients converged
    199 converged_grads = tuple(
    200     crit_grad(grads_, grads) for (grads_, grads) in zip(
--> 201         [ grads_.morphology.data for grads_ in grads_.sources ], 
    202         [ grads.morphology.data for grads in grads.sources ]))
    203 # spectrum converged
    204 converged_spec = tuple(
    205     crit_spec(spec, spec_) for (spec, spec_) in zip(
    206         [ spec.parameters['spectrum.data']for spec in scene.sources ], 
    207         [ spec.parameters['spectrum.data']for spec in scene_.sources ]))

AttributeError: 'GaussianMorphology' object has no attribute 'data'
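One way to avoid hard-coding `morphology.data` would be to run the relative-change test over every array leaf of each source, whatever its morphology class. A minimal sketch using `jax.tree_util.tree_leaves`. `tree_converged` is a hypothetical name, the criterion mirrors `crit_spec` above, and plain dicts stand in for the equinox source modules (which are also pytrees):

```python
import jax
import jax.numpy as jnp


def tree_converged(tree_old, tree_new, e_rel):
    """Relative-change test applied to every array leaf of two matching pytrees."""
    leaves_old = [l for l in jax.tree_util.tree_leaves(tree_old) if isinstance(l, jax.Array)]
    leaves_new = [l for l in jax.tree_util.tree_leaves(tree_new) if isinstance(l, jax.Array)]
    return all(
        bool(jnp.linalg.norm(a - b) < e_rel * jnp.linalg.norm(b))
        for a, b in zip(leaves_old, leaves_new)
    )


# an array-morphology source and a point source with Gaussian-style parameters
old = [{"morph": jnp.ones(4)}, {"center": jnp.array([10.0, 12.0]), "sigma": jnp.array(0.8)}]
new = [{"morph": jnp.ones(4) * 1.0001}, {"center": jnp.array([10.0, 12.0]), "sigma": jnp.array(0.8)}]
all_converged = all(tree_converged(o, n, 1e-3) for o, n in zip(old, new))
```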

Stepsize function is evaluated only at the beginning of Scene.fit

PR #16 introduced stepsize functions, i.e. a function f(parameter) = stepsize, so that one can have, e.g., relative step sizes. However, the function is evaluated only once, at the beginning of Scene.fit. It should be evaluated at every step so that significant changes in a parameter can alter its step size.

port initialization routines from scarlet1

scarlet1 has very well-designed initializations, which we should port as much as possible. In scarlet1, we took control of the entire source creation and initialized all of its parameters. But specifying sources works differently in scarlet2, where all parameters are exposed to the user:

Source(
        center,
        ArraySpectrum(Parameter(spectrum, constraint=constraints.positive, stepsize=1)),
        ArrayMorphology(Parameter(morph, constraint=constraints.positive, stepsize=1e-1))
    )

We could extend it to something very modular, following the flax/stax style of initialization:

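def init_spectrum(obs, center):
    # spectrum = pixel values in all channels at the (rounded) source center
    y, x = int(round(center[0])), int(round(center[1]))
    return obs.data[:, y, x]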

Source(
        center,
        ArraySpectrum(Parameter(None, constraint=constraints.positive, stepsize=1, init_fn=init_spectrum)),
        ArrayMorphology(Parameter(morph, constraint=constraints.positive, stepsize=1e-1))
    )

but then init_spectrum would need to know how to get obs and center. Much more transparent to the user, and enabled by the state-preserving design of equinox, is something like this:

spectrum = init_spectrum(obs, center)
Source(
        center,
        ArraySpectrum(Parameter(spectrum, constraint=constraints.positive, stepsize=1)),
        ArrayMorphology(Parameter(morph, constraint=constraints.positive, stepsize=1e-1))
    )

For the morphology, we would do the same; the main problem is that we don't (yet) have a port of the monotonicity operator from scarlet1. I think we could port it if we don't require it to be jittable.
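A rough, non-jittable sketch of a monotonicity projection in plain NumPy, for discussion only. It is a simplified nearest-neighbour scheme and not a port of scarlet1's operator (which uses weighted references): each pixel is clipped to the value of its neighbour one step closer to the center, processing pixels in order of increasing distance. `monotonic_projection` is a hypothetical name, and the center is assumed to be integer pixel coordinates `(y, x)` inside the image:

```python
import numpy as np


def monotonic_projection(morph, center):
    """Clip a 2D morphology so it never increases moving outward from `center` (y, x)."""
    out = np.array(morph, dtype=float, copy=True)
    cy, cx = center
    yy, xx = np.indices(out.shape)
    dist = np.hypot(yy - cy, xx - cx)
    # visit pixels from the center outward so each reference pixel is already final
    for flat_idx in np.argsort(dist, axis=None, kind="stable"):
        y, x = np.unravel_index(flat_idx, out.shape)
        if y == cy and x == cx:
            continue
        # neighbour one step closer to the center (diagonal steps allowed)
        ry = y - int(np.sign(y - cy))
        rx = x - int(np.sign(x - cx))
        out[y, x] = min(out[y, x], out[ry, rx])
    return out


morph = np.zeros((5, 5))
morph[2, 2], morph[1, 1], morph[0, 0] = 5.0, 2.0, 4.0  # (0, 0) violates monotonicity
projected = monotonic_projection(morph, (2, 2))
```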

Convergence criteria in Scene.py

I recommend moving all scarlet2 usage with priors to the main branch, as matt-nn-branch is in the process of being merged/removed.

The only functionality lost is the convergence criterion for the optimisation steps. This was not ideal in matt-nn-branch, so I will create a new routine based on the relative gradients returned by the ADAM optimiser. Progress will be tracked here.

Galaxy model taking over point source in a galaxy+point source blend

I've been running into some tricky issues when trying to deblend a galaxy with a bright point source and wanted to crowdsource some ideas. This is the scarlet1 model, which looks pretty good when rendered but has more noise than usual around the point source in the residuals, presumably because the point source is so bright.

[Image: scarlet1 model]

When I model this in Scarlet2, source 0 (a normal source with the HSC prior) tries to model the point source, while source 1 (a point source) has its best-fit flux set to 0. Then the actual galaxy is poorly modeled:

[Images: scarlet2 model]

I tried making a smaller box size just to see what happens when the galaxy doesn’t completely overlap the point source, and things go terribly wrong in a way that I can’t quite understand:

[Image: scarlet2 model with the smaller box]

Is there a good way to make sure the galaxy model doesn't use its flexibility to model the bright point source? There is no variability in this case - just 3 images in different bands.

Is step size updated by adam?

#27 brought a change where the iteration counter became an array that we pass to the update function, namely via optim.update(..., [model, it]). Our custom scale_by_stepsize splits the model parameters and the iteration counter correctly, but what does optax.scale_by_adam do with them? Does it run gradient descent on the counter? Please check!

scarlet2/scarlet2/scene.py

Lines 141 to 164 in c79f8f2

def scale_by_stepsize() -> base.GradientTransformation:
    # adapted from optax.scale_by_param_block_norm()
    def init_fn(params):
        del params
        return base.EmptyState()

    def update_fn(updates, state, params):
        params, it = params
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_util.tree_map(
            # lambda u, step, param: -step * u if not callable(step) else -step(param,niter) * u,
            lambda u, s, p: -s * u if not callable(s) else -s(p, it) * u,
            # minus because we want gradient descent
            updates, steps, params)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)

# run adam, followed by stepsize adjustments
optim = optax.chain(
    optax.scale_by_adam(**kwargs),
    scale_by_stepsize(),
)

Built-in initializations

Functions for Scarlet2 default initialisations for the morphology and spectrum are included here: https://github.com/pmelchior/scarlet2/blob/initializations/scarlet2/initialization.py

To use them, simply modify your fit routine by adding two lines:

init_spec = scarlet2.initialization.init_spectrum(obs, center)
init_morph = scarlet2.initialization.init_morphology(obs, center)

So the new routine looks like this:

with Scene(model_frame) as scene:
    for center in centers:
        # initialize the prior
        prior = nn.ScorePrior(
            model=prior_model, transform=transform, model_size=model_size
        )
        # initialize the spectrum and morphology with Scarlet2
        init_spec = scarlet2.initialization.init_spectrum(obs, center)
        init_morph = scarlet2.initialization.init_morphology(obs, center)
        
        # construct the source
        Source(
            center,
            ArraySpectrum(
                Parameter(init_spec, constraint=constraints.positive, stepsize=5e-2)
            ),
            ArrayMorphology(Parameter(init_morph, prior=prior, stepsize=5e-3)),
        )

# now fit the model
scene_fit = scene.fit(obs, max_iter=220, e_rel=1e-4)  
renders = obs.render(scene_fit())

Currently, the box size is not fit very nicely, so brainstorming and comments are welcome. It is also the speed bottleneck in the morphology initialisations.

Overall, this works well, producing reasonable initialisations, especially given our more robust optimisation scheme.
[Images: inits, renders]
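As a starting point for the box-size brainstorm, here is one simple heuristic (hypothetical, not what `initialization.py` does): try candidate sizes in increasing order and keep the first square box whose border pixels all fall below a multiple of the noise RMS. Restricting the candidates to sizes the prior accepts (32 or 64, per the assertion in `calc_grad`) might also sidestep padding problems. `choose_box_size` and its `k` threshold are assumptions:

```python
import numpy as np


def choose_box_size(image, center, noise_rms, sizes=(16, 32, 64), k=1.0):
    """Smallest candidate box whose border pixels are all below k * noise_rms.

    Returns the largest size that fits inside the image if no candidate satisfies
    the border criterion, or None if even the smallest box leaves the image.
    """
    cy, cx = center
    best = None
    for size in sorted(sizes):
        y0, x0 = cy - size // 2, cx - size // 2
        y1, x1 = y0 + size, x0 + size
        if y0 < 0 or x0 < 0 or y1 > image.shape[0] or x1 > image.shape[1]:
            break  # larger boxes would leave the image too
        best = size
        box = image[y0:y1, x0:x1]
        border = np.concatenate([box[0], box[-1], box[:, 0], box[:, -1]])
        if np.all(border < k * noise_rms):
            return size
    return best


# toy source: circular Gaussian with sigma = 3 pixels and peak 100, noise RMS 1
yy, xx = np.indices((128, 128))
image = 100.0 * np.exp(-((yy - 64) ** 2 + (xx - 64) ** 2) / (2 * 3.0 ** 2))
size = choose_box_size(image, (64, 64), noise_rms=1.0)
```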

Attach a HMC sampler to get full uncertainties

With gradients computed automatically, we should provide an option to run a sampler. It probably needs to accept pytrees as parameters, but we may be able to collapse these to standard parameter lists. I've had some trouble getting blackjax to run, but here's what I've got:

import jax
import blackjax
import equinox as eqx
from jax import random

# separate keys for warm-up and sampling (reusing one key correlates the two stages)
rng_key = random.PRNGKey(0)
warmup_key, sample_key = random.split(rng_key)
logdensity_fn = lambda x: -loss_fn(x, data)  # loss_fn, data, source defined elsewhere
init_state = eqx.filter(source, eqx.is_array)

adapt = blackjax.window_adaptation(
   blackjax.nuts, logdensity_fn, initial_step_size=0.01
)
last_state, kernel, _ = adapt.run(warmup_key, init_state)

# adapt = blackjax.meads(logdensity_fn, 1)
# last_state, kernel, _ = adapt.run(warmup_key, init_state)

print(last_state.position.spectrum(), loss_fn(last_state.position, data))

def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos

num_sample = 1000
states, infos = inference_loop(sample_key, kernel, last_state, num_sample)

print(states.potential_energy[-10:])
print(states.position.spectrum()[-10:])

One of the problems is that the window adaptation does not seem to work, and without it the inference is just a mess.
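On collapsing pytrees to standard parameter lists: JAX ships `jax.flatten_util.ravel_pytree`, which flattens a pytree of arrays into one 1D vector and returns the inverse function. A minimal sketch of wrapping a log-density so a sampler only sees a flat vector, with a plain dict standing in for the source and a toy Gaussian log-density standing in for `-loss_fn`:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# toy parameter pytree (stand-in for the array leaves of a scarlet2 source)
params = {"spectrum": jnp.array([1.0, 2.0, 3.0]), "morphology": jnp.ones((2, 2))}
flat, unravel = ravel_pytree(params)  # one vector holding all 7 entries


def logdensity_flat(theta):
    # rebuild the pytree from the flat vector, then evaluate a toy log-density
    p = unravel(theta)
    return -0.5 * (jnp.sum(p["spectrum"] ** 2) + jnp.sum(p["morphology"] ** 2))


grad_flat = jax.grad(logdensity_flat)(flat)  # same flat layout as `flat`
```

The sampler's positions are then flat vectors, and `unravel` maps each sample back to the original structure.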

stop_gradient instead of static parameters

From this comment by Patrick Kidger, it's clear that jax has a wider definition of staticness, namely being invisible to every jax transformation. We normally only want grad to ignore the parameters, for which he recommends jax.lax.stop_gradient.

Doing so is probably cleaner than what we do right now, but we need to find out how/where we call it programmatically and how it can be undone.
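A minimal sketch of the idea: wrap fixed parameters in `jax.lax.stop_gradient` inside the loss, so they stay ordinary traced arrays for `jit`/`vmap` but receive zero gradient. The `fixed` dict of Python booleans is a hypothetical way to control this; "undoing" it then means changing the flag, which requires a re-trace if the loss is jitted with `fixed` closed over:

```python
import jax
import jax.numpy as jnp


def loss(params, fixed):
    # `fixed` maps parameter names to Python bools (hypothetical interface)
    p = {k: jax.lax.stop_gradient(v) if fixed[k] else v for k, v in params.items()}
    return jnp.sum((p["center"] - 1.0) ** 2) + jnp.sum((p["flux"] - 2.0) ** 2)


params = {"center": jnp.array([0.0, 0.0]), "flux": jnp.array(0.0)}
g_fixed = jax.grad(loss)(params, {"center": True, "flux": False})
g_free = jax.grad(loss)(params, {"center": False, "flux": False})
# g_fixed["center"] is zero; g_free["center"] is 2 * (0 - 1) = -2 per entry; flux gets -4 in both
```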
