Thank you for this great setup! I came across a bug in binder design when using an MSA (I understand that MSA mode has so far only been tested for fixbb). I get the following error, both locally and on Colab:
Traceback (most recent call last):
File "colabdesign/tools/af_design_motifs.py", line 799, in <module>
design_model.design(50, weights={"plddt":0.1,"pae":0.1,"ent":ent})
File "colabdesign/tools/af_design_motifs.py", line 634, in design
self._state, outs, loss = step(self._k, self._state, subkey, opt)
File "colabdesign/tools/af_design_motifs.py", line 589, in step
(loss, outs), grad = self._grad(self._get_params(state), self._params[n], self._inputs, key, opt)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/api.py", line 433, in cache_miss
donated_invars=donated_invars, inline=inline)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 1681, in bind
return call_bind(self, fun, *args, **params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 1693, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 594, in process_call
return primitive.impl(f, *tracers, **params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
ans = call(fun, *args)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
*arg_specs).compile().unsafe_call
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 198, in lower_xla_callable
fun, abstract_args, pe.debug_info_final(fun, "jit"))
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1680, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1657, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/api.py", line 1073, in value_and_grad_f
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/api.py", line 2528, in _vjp
flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 118, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 103, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 522, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "colabdesign/tools/af_design_motifs.py", line 314, in mod
seq_hard = jnp.concatenate([seq_target[None], seq_hard], 1)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in concatenate
for i in range(0, len(arrays), k)]
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in <listcomp>
for i in range(0, len(arrays), k)]
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 557, in concatenate
return concatenate_p.bind(*operands, dimension=dimension)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 272, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 275, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 289, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 440, in linear_jvp
val_out = primitive.bind(*primals, **params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 272, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 275, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1404, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1408, in default_process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 2762, in _concatenate_shape_rule
raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 85, 20), (2, 13, 20).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "colabdesign/tools/af_design_motifs.py", line 799, in <module>
design_model.design(50, weights={"plddt":0.1,"pae":0.1,"ent":ent})
File "colabdesign/tools/af_design_motifs.py", line 634, in design
self._state, outs, loss = step(self._k, self._state, subkey, opt)
File "colabdesign/tools/af_design_motifs.py", line 589, in step
(loss, outs), grad = self._grad(self._get_params(state), self._params[n], self._inputs, key, opt)
File "colabdesign/tools/af_design_motifs.py", line 314, in mod
seq_hard = jnp.concatenate([seq_target[None], seq_hard], 1)
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in concatenate
for i in range(0, len(arrays), k)]
File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in <listcomp>
for i in range(0, len(arrays), k)]
TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 85, 20), (2, 13, 20).
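The target is 85 residues long and the binder is 13 residues long. The concatenation at line 314 is along axis 1 (the residue axis), so every other axis must match. However, axis 0 is 1 for the target (`seq_target[None]`) and 2 for the binder's `seq_hard` (presumably one row per MSA sequence). How should the two arrays be concatenated in this case? I could not figure it out.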