Comments (11)

btaba avatar btaba commented on June 2, 2024 1

Hi @joeryjoery , I believe we considered passing around sys as part of the env state, but IIRC we managed to squeeze out better performance using the current implementation.

class DomainRandomizationVmapWrapper(Wrapper):

Feel free to implement a version of the base env class and wrapper which passes the sys in a functional way (e.g. as part of the state.info). If you manage to get the same training performance out of it, please send it our way!

joeryjoery avatar joeryjoery commented on June 2, 2024

Wanted to add an example of another workaround: https://github.com/automl/CARL.

In this meta-RL library, instead of batching environments on the GPU (which Brax should support), the CARL-brax environments create VectorizedWrappers from Gymnasium in order to run multiple System variations simultaneously, which rather defeats the purpose of GPU parallelization.

joeryjoery avatar joeryjoery commented on June 2, 2024

Hey, thanks for the reply. A big obstacle right now in trying to implement something like this is that the pipeline.init and pipeline.step functions are quite rigid: they only receive self, q, qd, and _debug as arguments.

So I'm trying to work around this by doing dependency injection for self, converting it into a PyTree so that I can apply JAX transforms to pipeline.init and so on. But mocking this object is causing quite a few problems, since I keep running into unforeseen dependencies. For this reason I don't think this approach is great, as it will definitely lead to problems later on.

@btaba Could the pipeline.init and pipeline.apply functions perhaps be extended to receive an optional options dictionary? This would require the API to propagate the options in reset and step from wrappers to base (i.e., like the Gymnasium implementation).

In principle, if the options are None then the performance stays the same, and if I want to provide options then I can wrap the pipeline module with my custom function that modifies self.sys.

What do you think?

btaba avatar btaba commented on June 2, 2024

Hi @joeryjoery ,

I'm not quite following why you want to add extra args to pipeline.init and pipeline.step. Does something like this not work: jax.vmap(pipeline.init, in_axes=[custom_in_axes, None, None])(sys, q, qd) ?
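
For readers unfamiliar with passing an in_axes tree to jax.vmap, here is a minimal sketch of the pattern suggested above. ToySys and toy_init are hypothetical stand-ins for illustration only; they are not Brax's System or pipeline.init. The in_axes entry for the system mirrors its structure: None marks a field shared by all environments and 0 marks a field with a leading batch axis.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySys(NamedTuple):
    """Hypothetical stand-in for a physics system (not brax.base.System)."""
    gravity: jax.Array    # shared by every environment in the batch
    viscosity: jax.Array  # one value per environment


def toy_init(sys: ToySys, q: jax.Array, qd: jax.Array) -> jax.Array:
    """Hypothetical stand-in for pipeline.init: combines sys fields with q, qd."""
    return qd + sys.gravity - sys.viscosity * q


batch_size = 4
sys = ToySys(
    gravity=jnp.asarray(-9.81),
    viscosity=jnp.linspace(0.0, 0.3, batch_size),
)

# Tree-shaped in_axes: map over viscosity only, broadcast gravity.
sys_axes = ToySys(gravity=None, viscosity=0)

q = jnp.ones(3)
qd = jnp.zeros(3)

# q and qd are broadcast (None); only the batched sys field is mapped.
out = jax.vmap(toy_init, in_axes=(sys_axes, None, None))(sys, q, qd)
print(out.shape)
```

The result has one row per system variation, even though q and qd were passed unbatched.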

joeryjoery avatar joeryjoery commented on June 2, 2024

Hey, yes this works. But it's not the problem.

The issue is that I have no easy way to propagate sys variations to that point (at least not in a way that is jittable). So for example, the Ant environment has a reset which looks something like this,

def reset(self, rng: jax.Array) -> State:
  """Resets the environment to an initial state."""
  rng, rng1, rng2 = jax.random.split(rng, 3)
  
  ...
  
  pipeline_state = self.pipeline_init(q, qd)
  obs = self._get_obs(pipeline_state)
  ...

Now suppose I want to wrap Ant: I do not have direct access to the self.pipeline_init call, so I cannot modularly apply jax.vmap(pipeline.init, ...) from the wrapper.

A way to solve this is to allow options, for example,

def reset(self, rng: jax.Array, *, options: dict | None = None) -> State:
  """Resets the environment to an initial state."""
  rng, rng1, rng2 = jax.random.split(rng, 3)
  
  ...
  
  pipeline_state = self.pipeline_init(q, qd, options=options)  # Pass along here
  obs = self._get_obs(pipeline_state)
  ...

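In this way, I can wrap env._pipeline with a function like this:

import types

def my_wrapped_init(self, q, qd, *, options: dict | None = None):
  if options is not None:
    variations = some_sampling_function(options)  # returns dict of sys fields
    sys = self.sys.replace(**variations)
    # in_axes=0 for sys assumes every leaf of sys has a leading batch axis.
    return jax.vmap(self._pipeline.init, in_axes=(0, None, None, None))(sys, q, qd, self._debug)

  # No options given: fall back to the unmodified system.
  return self._pipeline.init(self.sys, q, qd, self._debug)

# Bind after the definition, as a method, so that self is passed in.
my_env.pipeline_init = types.MethodType(my_wrapped_init, my_env)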

btaba avatar btaba commented on June 2, 2024

Comments and questions on the proposed changes:

[1] Subsume part of System inside State: You can do this already by adding System to state.info, and re-writing your env code to use state.info['sys'] instead of self.sys. How performant is that implementation for RL workloads? Then we can discuss a potential API change.

[2] Add Options to reset: Strong preference here to add your logic to a wrapper, and to split out the vmap case from the non-vmap case into distinct wrappers. It looks like your proposal is similar to the DomainRandomizationVmapWrapper except you want to do the sys.replace at pipeline.init/pipeline.step time? Does this mean that the env.reset and env.step logic won't be accessing the same randomized version of sys?

joeryjoery avatar joeryjoery commented on June 2, 2024

Hey thanks a lot for continuing the discussion.

TL;DR: I was overthinking this, and the easy solution is indeed a slight modification of DomainRandomizationVmapWrapper.

  1. The problem with the current DomainRandomizationVmapWrapper is that the randomization is done in __init__ and not in reset. If I want to resample variations at every call to reset, I instead have to reinstantiate the class, which would mean recompiling reset and step, which is costly.

  2. What I did now is make randomization_fn dependent on a random key and call it inside reset. The sampled variations are then replaced inside System and stored inside State.info. The info only contains the varied fields, so that we don't redundantly pass around data.

In my implementation I also do not include vmap, as I think it is much easier to just vmap over the DomainRandomization wrapper. I have not tested performance, but the code is much more readable.

This is what I propose:

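from typing import Any, Callable

import jax
import brax.envs
from brax.base import System
from brax.envs.base import Env, State


class DomainRandomization(brax.envs.Wrapper):
    """Wrapper for Procedural Domain Randomization."""
    
    def __init__(
        self, 
        env: Env, 
        # Returns a dict of varied System fields (see the usage example).
        randomization_fn: Callable[[System, jax.Array], dict[str, Any]]
    ):
        super().__init__(env)
        self.randomization_fn = randomization_fn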

    def env_fn(self, sys: System) -> Env:
        env = self.env
        env.unwrapped.sys = sys
        return env
    
    def reset(self, rng: jax.Array) -> State:
        key_reset, key_var = jax.random.split(rng)
        
        sys = self.env.unwrapped.sys
        variations = self.randomization_fn(sys, key_var)

        new_sys = sys.replace(**variations)
        new_env = self.env_fn(new_sys)
        
        state = new_env.reset(key_reset)
        state = state.replace(info=state.info | {'sys_var': variations})
        
        return state
        
    def step(self, state: State, action: jax.Array) -> State:

        variations = state.info['sys_var']

        sys = self.env.unwrapped.sys
        new_sys = sys.replace(**variations)

        new_env = self.env_fn(new_sys)
        state = new_env.step(state, action)
        
        state = state.replace(info=state.info | {'sys_var': variations})
        
        return state

Example usage:

def viscosity_randomizer(system: System, key: jax.Array) -> dict[str, Any]:
    return {'viscosity': jax.random.uniform(key, system.viscosity.shape)}

env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,
)

wrap = DomainRandomization(env, viscosity_randomizer)

s0 = jax.jit(wrap.reset)(jax.random.key(0))
s1 = jax.jit(wrap.reset)(jax.random.key(321))

print(s0.info['sys_var'], s1.info['sys_var'])
>> {'viscosity': Array(0.10536897, dtype=float32)} {'viscosity': Array(0.3906865, dtype=float32)}


print(w.unwrapped.sys.viscosity)
>> Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
print(w.default_sys.viscosity)
>> 0.0

Or composing with the VmapWrapper,

sbatch = jax.jit(brax.envs.wrappers.training.VmapWrapper(wrap).reset)(
    jax.random.split(jax.random.key(0), 5)
)
print(sbatch.info['sys_var'])
>> {'viscosity': Array([0.6306313 , 0.5778805 , 0.64515114, 0.95315635, 0.24741197],      dtype=float32)}
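
The Traced<...> value printed above appears because env_fn assigns a traced sys onto the shared env object while reset is being jitted. As a point of comparison, here is a minimal sketch of a fully functional variant on a toy environment, where the sampled variations only ever travel through the state. ToyState, sample_variations, reset and step are hypothetical names for illustration and are not part of Brax.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for an env state with an info dict."""
    x: jax.Array
    info: dict[str, Any]


def sample_variations(key: jax.Array) -> dict[str, Any]:
    # Only the varied fields are sampled and stored.
    return {'viscosity': jax.random.uniform(key, ())}


def reset(key: jax.Array) -> ToyState:
    key_reset, key_var = jax.random.split(key)
    variations = sample_variations(key_var)
    x = jax.random.normal(key_reset, (2,))
    # Variations live in the state, not on a mutable env object.
    return ToyState(x=x, info={'sys_var': variations})


def step(state: ToyState, action: jax.Array) -> ToyState:
    # Read the per-environment variation back out of the state.
    viscosity = state.info['sys_var']['viscosity']
    x = (1.0 - viscosity) * state.x + action
    return state._replace(x=x)


# Batch by vmapping over keys: one sampled system per environment.
keys = jax.random.split(jax.random.key(0), 5)
batch = jax.jit(jax.vmap(reset))(keys)
batch = jax.jit(jax.vmap(step))(batch, jnp.zeros((5, 2)))
print(batch.info['sys_var']['viscosity'].shape)
```

Because nothing is written to objects outside the traced functions, no tracers can leak, at the cost of carrying the varied fields in every state.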

joeryjoery avatar joeryjoery commented on June 2, 2024

It's not really easy to show that this implementation works here, but if you visualize the results using the code shown in the Colab, you can see that it indeed randomizes the System variables per random key.

https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb#scrollTo=4hHuDp53e4VJ

I also haven't tested performance for RL training. But if your goal is to randomize at every reset call, it should be faster than the current DomainRandomizationVmapWrapper, whose randomization_fn is applied non-purely in __init__, so resampling means re-instantiating the wrapper and recompiling reset and step.
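
To illustrate the recompilation argument in isolation, here is a small sketch in plain JAX (no Brax involved). It counts how often a function is traced when a parameter is baked in through a closure and the jitted function is rebuilt for each new value, versus when the parameter is passed as an argument. The names make_step_closure and step_argument are hypothetical.

```python
import jax
import jax.numpy as jnp

# Python side effects inside a jitted function run only while tracing,
# so these counters record the number of traces.
trace_count = {'closure': 0, 'argument': 0}


def make_step_closure(viscosity: float):
    """Rebuilds a jitted function with viscosity baked in as a constant."""
    @jax.jit
    def step(x):
        trace_count['closure'] += 1
        return x * (1.0 - viscosity)
    return step


@jax.jit
def step_argument(x, viscosity):
    """Takes viscosity as a traced argument instead."""
    trace_count['argument'] += 1
    return x * (1.0 - viscosity)


x = jnp.ones(3)
for v in (0.1, 0.2, 0.3):
    # New function object each time, so JAX has to trace it again.
    make_step_closure(v)(x)
    # Same function, same shapes and dtypes, so the cached trace is reused.
    step_argument(x, jnp.asarray(v, dtype=jnp.float32))

print(trace_count)
```

The closure version is traced once per new value, while the argument version is traced once in total. This is the cost that sampling inside reset avoids.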

btaba avatar btaba commented on June 2, 2024

Hi @joeryjoery

I think we tried a version of this implementation. A few comments:

[1] Can you update your impl to make it work for nested fields in sys? You can probably use tree_replace
[2] IIRC passing these extra vars in the info was costly for an RL workload. Can you compare the performance of your current version against the version at HEAD to see where we're at, and randomize a few more parameters (esp. ones that scale with nv, nq, and ngeom)? Maybe try this on humanoid. So you'd potentially be passing (batch_size, ngeom) parameters in the state.info.

FWIW, the impl at HEAD, despite creating a static batch of sys, is enough for sim2real transfer on a quadruped. You can also do multiple resets in training like here (if you're concerned about the static part):

for _ in range(max(num_resets_per_eval, 1)):
  # optimization
  epoch_key, local_key = jax.random.split(local_key)
  epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
  (training_state, env_state, training_metrics) = (
      training_epoch_with_timing(training_state, env_state, epoch_keys)
  )
  current_step = int(_unpmap(training_state.env_steps))
  key_envs = jax.vmap(
      lambda x, s: jax.random.split(x[0], s),
      in_axes=(0, None))(key_envs, key_envs.shape[1])
  # TODO: move extra reset logic to the AutoResetWrapper.
  env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state

joeryjoery avatar joeryjoery commented on June 2, 2024

Hey!

For 1), I was working on something like this but didn't quite finish today; I will update it later. What do you mean by tree_replace? Is it a private Brax API? I was thinking more along the lines of mocking System with a nested dictionary.

For 2), I don't think there is a way around this: we are passing around more data. If the variations are small (like just the viscosity or gravity), then I'd imagine the cost is negligible, but it can indeed grow for something like Humanoid with mass or geom variations. Though I'd imagine there are some optimizations possible here.

I'm not suggesting that the other DomainRandomizationVmapWrapper is wrong; if it works well for sim2real, that's amazing.

However, I'm specifically trying to fulfil my research assumptions as well as I can. These assume a random environment for every sampled trajectory, which also makes learning a good policy considerably more difficult. Also, in my experiments data collection is rarely the bottleneck; I've found it is more often the learner (at least for my very specific use case: PPO with a recurrent network architecture that also does internal matrix inversions).

If I find the time, I'll try to run the default agent with the current domain randomization and with the one I posted.

btaba avatar btaba commented on June 2, 2024

Hi @joeryjoery , tree_replace can be found here:

def tree_replace(

Thanks for the context on [2], I recommend using your own wrapper (for ensuring sampling a new system for every trajectory), looks like you're pretty close to a more general version with the implementation above! Let us know if you have any trouble and please feel free to share any findings (or open a PR)
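
To make the idea of replacing nested fields concrete, here is a minimal sketch using dotted paths on frozen dataclasses. This is not Brax's tree_replace implementation; ToyOpt, ToySys, replace_nested and replace_many are hypothetical names, and the sketch only assumes that nested fields can be addressed with keys such as 'opt.viscosity'.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyOpt:
    """Hypothetical nested options block."""
    viscosity: float
    gravity: float


@dataclasses.dataclass(frozen=True)
class ToySys:
    """Hypothetical stand-in for a system with a nested field."""
    opt: ToyOpt
    mass: float


def replace_nested(obj: Any, path: str, value: Any) -> Any:
    """Returns a copy of obj with the field at a dotted path replaced."""
    head, _, rest = path.partition('.')
    if not rest:
        return dataclasses.replace(obj, **{head: value})
    # Recurse into the child, then rebuild the parent around the new child.
    child = replace_nested(getattr(obj, head), rest, value)
    return dataclasses.replace(obj, **{head: child})


def replace_many(obj: Any, params: dict[str, Any]) -> Any:
    """Applies several dotted-path replacements in turn."""
    for path, value in params.items():
        obj = replace_nested(obj, path, value)
    return obj


sys = ToySys(opt=ToyOpt(viscosity=0.0, gravity=-9.81), mass=1.0)
new_sys = replace_many(sys, {'opt.viscosity': 0.2, 'mass': 2.0})
print(new_sys)
```

With a helper like this, a randomization function could return a flat dict of dotted keys, and the wrapper would no longer be limited to top-level fields as it is with sys.replace(**variations).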
