flybody: fruit fly body model for MuJoCo physics

flybody is an anatomically detailed body model of the fruit fly Drosophila melanogaster for the MuJoCo physics simulator and reinforcement learning applications.

The fly model was developed in a collaborative effort by Google DeepMind and HHMI Janelia Research Campus.

We envision our model as a platform for fruit fly biophysics simulations and for modeling neural control of sensorimotor behavior in an embodied context; see our accompanying publication.

Getting Started

The fruit fly body model lives in the flybody/fruitfly/assets directory. To visualize it, you can drag-and-drop fruitfly.xml or floor.xml into MuJoCo's simulate viewer.

Interacting with the model from Python is as simple as:

from dm_control import mujoco

physics = mujoco.Physics.from_xml_path('flybody/fruitfly/assets/fruitfly.xml')  # Load model.
physics.step()  # Step simulation.
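
From here you can also render frames offscreen, for example (a minimal sketch; camera_id=0 assumes the model defines at least one camera, and can be omitted to use the default free camera):

pixels = physics.render(height=240, width=320, camera_id=0)  # RGB array of shape (240, 320, 3)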

The quickest way to get started with flybody is to take a look at the tutorial notebook, which can also be opened and run in Colab.

Also, the fly-env-examples notebook (docs/fly-env-examples.ipynb) shows examples of the flight, walking, and vision-guided flight RL task environments.
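
For orientation, constructing and stepping one of these task environments could look roughly like the following (a hedged sketch: the flybody.fly_envs import path and the data-file paths are assumptions based on the notebook and the issue reports below):

import numpy as np
from flybody.fly_envs import flight_imitation  # import path assumed

# The wing-beat-pattern and reference-flight paths are hypothetical placeholders.
env = flight_imitation('path/to/wpg_pattern.npy',
                       'path/to/ref_flight_trajectories.hdf5')
timestep = env.reset()
# A do-nothing action, assuming zero lies within the action spec bounds.
action = np.zeros(env.action_spec().shape, dtype=np.float32)
timestep = env.step(action)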

To train the fly, try the distributed RL training script, which uses Ray to parallelize the DMPO agent training.

Installation

Follow these steps to install flybody:

Option 1: Installation from cloned local repo

  1. Clone this repo and create a new conda environment:

    git clone https://github.com/TuragaLab/flybody.git
    cd flybody
    conda env create -f flybody.yml
    conda activate flybody

    flybody can be installed in one of the three modes described next. For installation in editable (developer) mode, use the commands as shown. For a regular, non-editable installation, drop the -e flag.

  2. Core installation: minimal installation for experimenting with the fly model in MuJoCo or prototyping task environments. ML dependencies such as TensorFlow and Acme are not included, and policy rollouts and training are not supported out of the box.

    pip install -e .
  3. ML extension (optional): same as the core installation, plus ML dependencies (TensorFlow, Acme) to allow running policy networks, e.g. for inference or for training with third-party agents not included in this library.

    pip install -e .[tf]
  4. Ray training extension (optional): same as core installation and ML extension, plus Ray to also enable distributed policy training in the fly task environments.

    pip install -e .[ray]

Option 2: Installation from remote repo

  1. Create a new conda environment:
    conda create --name flybody python=3.10 pip ipython cudatoolkit cudnn=8.2.1=cuda11.3_0
    conda activate flybody
    Proceed with installation in one of the three modes (described above):
  2. Core installation:
    pip install git+https://github.com/TuragaLab/flybody.git
  3. ML extension (optional):
    pip install "flybody[tf] @ git+https://github.com/TuragaLab/flybody.git"
  4. Ray training extension (optional):
    pip install "flybody[ray] @ git+https://github.com/TuragaLab/flybody.git"

Additional configuring

  1. You may need to set MuJoCo rendering environment variables, e.g.:

    export MUJOCO_GL=egl
    export MUJOCO_EGL_DEVICE_ID=0
  2. Also, for the ML and Ray extensions, LD_LIBRARY_PATH may require an update, e.g.:

    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/your/path/to/miniconda3/envs/flybody/lib
  3. You may want to run pytest to test the main components of the flybody installation.
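
    For example, run from the repo root (assuming pytest is available in the environment, e.g. via the flybody.yml conda env):

    pytest -v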

Citing flybody

See our accompanying publication. Thank you for your interest in our fly model :)

@article{flybody,
  title = {Whole-body simulation of realistic fruit fly locomotion with
           deep reinforcement learning},
  author = {Roman Vaxenburg and Igor Siwanowicz and Josh Merel and Alice A Robie and
            Carmen Morrow and Guido Novati and Zinovia Stefanidi and Gwyneth M Card and
            Michael B Reiser and Matthew M Botvinick and Kristin M Branson and
            Yuval Tassa and Srinivas C Turaga},
  journal = {bioRxiv},
  doi = {10.1101/2024.03.11.584515},
  url = {https://www.biorxiv.org/content/10.1101/2024.03.11.584515},
  year = {2024},
}

flybody's Issues

Training Checkpoint Loading Error

Hi Team,

When I tried to load a checkpoint to continue training via DMPOConfig.checkpoint_to_load = {run_name}/checkpoints/dmpo_learner/ckpt-4.index, I got the following AssertionError inside create_dual_variables_once.

ray.exceptions.RayTaskError(AssertionError): ray::Learner.run() (pid=3441185, ip=10.244.10.88, actor_id=155580693110898c9f9447d701000000, repr=<flybody.agents.ray_distributed_dmpo.Learner object at 0x7f0c90c71240>)
  File "/root/vast/scott-yang/flybody/flybody/agents/ray_distributed_dmpo.py", line 208, in run
    self.step()
  File "/root/vast/scott-yang/flybody/flybody/agents/learning_dmpo.py", line 324, in step
    fetches = self._step()
  File "/root/vast/scott-yang/flybody/flybody/agents/ray_distributed_dmpo.py", line 201, in _step
    return DistributionalMPOLearner._step(self)
  File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1147, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:

    File "/root/vast/scott-yang/flybody/flybody/agents/learning_dmpo.py", line 274, in _step  *
        policy_loss, policy_stats = self._policy_loss_module(
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)
    File "/root/vast/scott-yang/flybody/flybody/agents/losses_mpo.py", line 216, in __call__  *
        self.create_dual_variables_once(dual_variable_shape, scalar_dtype)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/once.py", line 89, in wrapper  *
        _check_no_output(wrapped(*args, **kwargs))
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)
    File "/root/vast/scott-yang/flybody/flybody/agents/losses_mpo.py", line 142, in create_dual_variables_once  *
        self._log_temperature = tf.Variable(

    AssertionError:

This error message from tf.Variable is not particularly helpful, since I cannot find any actual assert statements pointing to what I should fix in my code. However, training runs fine without loading the checkpoint.

Looking deeper into the checkpoint-loading logic, I don't quite understand the logic at this line, where the variable _checkpoint is created and restored but never used.

Could you help me with this issue? Any clarification and guidance is greatly appreciated!

Scott Yang.
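
For what it's worth, the unused-_checkpoint pattern in the last point is normal TensorFlow usage: tf.train.Checkpoint.restore works by side effect, matching saved values to variables through the tracked object graph and writing them in place, so the checkpoint object need not be referenced again. A generic, self-contained sketch (not flybody code; note that restore() takes the checkpoint prefix such as ckpt-4, without the .index suffix):

import tensorflow as tf

# Stand-ins for the objects the learner tracks at save time.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))

checkpoint = tf.train.Checkpoint(model=model)
save_prefix = checkpoint.save('/tmp/demo/ckpt')  # returns e.g. '/tmp/demo/ckpt-1'

# restore() takes the prefix (no '.index' suffix) and writes the saved values
# into the tracked variables in place; the checkpoint object is not used afterwards.
status = checkpoint.restore(save_prefix)
status.assert_existing_objects_matched()  # optional: verify variables were matched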

Vision observables datatype mismatch with the rest

Hi Team,

I am trying to run the DMPO algorithm on the vision_guided_flight task and environment in train_dmpo_ray.py. The script works fine for me on tasks and environments without vision, but it fails when I try to train a task that involves vision.

Specifically, running DMPO on vision_guided_flight yields the following error:

(Learner pid=2663751) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::Learner.__init__() (pid=2663751, ip=10.244.10.88, actor_id=a69eff6600a70d3cf23e4fa701000000, repr=<flybody.agents.ray_distributed_dmpo.Learner object at 0x7f62530651e0>)
(Learner pid=2663751)   File "/root/vast/scott-yang/flybody/flybody/agents/ray_distributed_dmpo.py", line 138, in __init__
(Learner pid=2663751)     online_networks.init(environment_spec)
(Learner pid=2663751)   File "/root/vast/scott-yang/flybody/flybody/agents/agent_dmpo.py", line 80, in init
(Learner pid=2663751)     emb_spec = utils.create_variables(self.observation_network, [obs_spec])
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/acme/tf/utils.py", line 103, in create_variables
(Learner pid=2663751)     dummy_output = network(*add_batch_dim(dummy_input))
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method
(Learner pid=2663751)     return decorator_fn(bound_method, self, args, kwargs)
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope
(Learner pid=2663751)     return method(*args, **kwargs)
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/acme/tf/utils.py", line 144, in __call__
(Learner pid=2663751)     return self._transformation(*args, **kwargs)
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/acme/tf/utils.py", line 54, in batch_concat
(Learner pid=2663751)     return tf.concat(tree.flatten(flat_leaves), axis=-1)
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
(Learner pid=2663751)     raise e.with_traceback(filtered_tb) from None
(Learner pid=2663751)   File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 7186, in raise_from_not_ok_status
(Learner pid=2663751)     raise core._status_to_exception(e) from None  # pylint: disable=protected-access
(Learner pid=2663751) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute ConcatV2 as input #5(zero-based) was expected to be a float tensor but is a uint8 tensor [Op:ConcatV2] name: concat

Further debugging by printing out the environment's observation, I found the following:

OrderedDict([('walker/accelerometer', Array(shape=(3,), dtype=dtype('float32'), name='walker/accelerometer')), ('walker/actuator_activation', Array(shape=(0,), dtype=dtype('float32'), name='walker/actuator_activation')), ('walker/gyro', Array(shape=(3,), dtype=dtype('float32'), name='walker/gyro')), ('walker/joints_pos', Array(shape=(25,), dtype=dtype('float32'), name='walker/joints_pos')), ('walker/joints_vel', Array(shape=(25,), dtype=dtype('float32'), name='walker/joints_vel')), ('walker/left_eye', BoundedArray(shape=(32, 32, 3), dtype=dtype('uint8'), name='walker/left_eye', minimum=0, maximum=255)), ('walker/right_eye', BoundedArray(shape=(32, 32, 3), dtype=dtype('uint8'), name='walker/right_eye', minimum=0, maximum=255)), ('walker/velocimeter', Array(shape=(3,), dtype=dtype('float32'), name='walker/velocimeter')), ('walker/world_zaxis', Array(shape=(3,), dtype=dtype('float32'), name='walker/world_zaxis')), ('walker/task_input', Array(shape=(2,), dtype=dtype('float32'), name='walker/task_input'))])

Here, the walker/left_eye and walker/right_eye observations have dtype=uint8.

What should the logic be to convert the dtype of the vision observables? Are there other wrappers available for the vision-enabled tasks?

Thank you!
Scott
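
One possible direction for the dtype question is a small custom dm_env wrapper that casts the uint8 eye observations to float32 in [0, 1]; the sketch below is hypothetical code, not part of flybody (another option would be to cast inside the observation network itself):

import dm_env
from dm_env import specs
import numpy as np


class VisionToFloatWrapper(dm_env.Environment):
    """Hypothetical wrapper: casts uint8 observables to float32 in [0, 1]."""

    def __init__(self, env: dm_env.Environment):
        self._env = env

    def _cast(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
        observation = {
            k: v.astype(np.float32) / 255.0 if v.dtype == np.uint8 else v
            for k, v in timestep.observation.items()
        }
        return timestep._replace(observation=observation)

    def reset(self) -> dm_env.TimeStep:
        return self._cast(self._env.reset())

    def step(self, action) -> dm_env.TimeStep:
        return self._cast(self._env.step(action))

    def observation_spec(self):
        spec = dict(self._env.observation_spec())
        for k, v in spec.items():
            if v.dtype == np.uint8:  # replace BoundedArray(uint8) with a float32 Array
                spec[k] = specs.Array(shape=v.shape, dtype=np.float32, name=v.name)
        return spec

    def action_spec(self):
        return self._env.action_spec()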

Environment Creation

Hi,

If I have to create an environment and then export it to MuJoCo XML format, what kind of CAD tool should I use to design the environment?

-GR

Controlling the fly legs using torque control

Hi! From your paper, I see that the wings can be controlled by inputting torques. My question is: can I do the same with the leg joints?
Also, what units should the input torque be in?
Thanks!

Mismatched action spec size in `flight_imitation` and `vision_guided_flight` envs

Hey guys,

I'm trying to run the docs/fly-env-examples.ipynb notebook and am having issues with the flight imitation environment.

Everything up to here works fine and I can render from the camera:

env = flight_imitation(wpg_pattern_path,
                       ref_flight_path,
                       terminal_com_dist=float('inf'))
env = wrappers.SinglePrecisionWrapper(env)
env = wrappers.CanonicalSpecWrapper(env, clip=True)

_ = env.reset()
pixels = env.physics.render(camera_id=1, **render_kwargs)

Running the next cell throws an error though:

random_policy = get_random_policy(env.action_spec())

frames = rollout_and_render(env, random_policy, run_until_termination=True,
                            camera_ids=1, **render_kwargs)
display_video(frames)

Raises:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 3
      1 random_policy = get_random_policy(env.action_spec())
----> 3 frames = rollout_and_render(env, random_policy, run_until_termination=True,
      4                             camera_ids=1, **render_kwargs)
      5 display_video(frames)

File ~/flybody/flybody/utils.py:33, in rollout_and_render(env, policy, n_steps, run_until_termination, camera_ids, **render_kwargs)
     31     frames.append(frame)
     32     action = policy(timestep.observation)
---> 33     timestep = env.step(action)
     34 return frames

File ~/conda/envs/flybody/lib/python3.10/site-packages/acme/wrappers/canonical_spec.py:52, in CanonicalSpecWrapper.step(self, action)
     50 def step(self, action: types.NestedArray) -> dm_env.TimeStep:
     51   scaled_action = _scale_nested_action(action, self._action_spec, self._clip)
---> 52   return self._environment.step(scaled_action)

File ~/conda/envs/flybody/lib/python3.10/site-packages/acme/wrappers/single_precision.py:37, in SinglePrecisionWrapper.step(self, action)
     36 def step(self, action) -> dm_env.TimeStep:
---> 37   return self._convert_timestep(self._environment.step(action))

File ~/conda/envs/flybody/lib/python3.10/site-packages/dm_control/composer/environment.py:416, in Environment.step(self, action)
    413   self._reset_next_step = False
    414   return self.reset()
--> 416 self._hooks.before_step(self._physics_proxy, action, self._random_state)
    417 self._observation_updater.prepare_for_next_control_step()
    419 try:

File ~/conda/envs/flybody/lib/python3.10/site-packages/dm_control/composer/environment.py:137, in _EnvironmentHooks.before_step(self, physics, action, random_state)
    134 if self._episode_step_count % _STEPS_LOGGING_INTERVAL == 0:
    135   logging.info('The current episode has been running for %d steps.',
    136                self._episode_step_count)
--> 137 self._task.before_step(physics, action, random_state)
    138 for entity_hook in self._before_step.entity_hooks:
    139   entity_hook(physics, random_state)

File ~/flybody/flybody/tasks/flight_imitation.py:164, in FlightImitationWBPG.before_step(self, physics, action, random_state)
    160 self._ghost.set_pose(physics, ghost_qpos[:3], ghost_qpos[3:])
    161 self._ghost.set_velocity(physics, self._ref_qvel[step, :3],
    162                          self._ref_qvel[step, 3:])
--> 164 super().before_step(physics, action, random_state)

File ~/flybody/flybody/tasks/base.py:201, in FruitFlyTask.before_step(self, physics, action, random_state)
    199 if self._action_corruptor is not None:
    200     action = self._action_corruptor(action, random_state)
--> 201 self._walker.apply_action(physics, action, random_state)

File ~/flybody/flybody/fruitfly/fruitfly.py:502, in FruitFly.apply_action(***failed resolving arguments***)
    500     return
    501 # Update previous action.
--> 502 self._prev_action[:] = action
    503 # Apply MuJoCo actions.
    504 ctrl = np.zeros(physics.model.nu)

ValueError: could not broadcast input array from shape (12,) into shape (11,)

Inspecting the env, it looks like it should be 12-dimensional:

env.action_spec()
BoundedArray(shape=(12,), dtype=dtype('float32'), name='head_abduct\thead_twist\thead\twing_yaw_left\twing_roll_left\twing_pitch_left\twing_yaw_right\twing_roll_right\twing_pitch_right\tabdomen_abduct\tabdomen\tuser_0', minimum=[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.], maximum=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])

Further confirming this, it looks like the provided checkpoint from figshare also has a 12-dimensional action space:

flight_policy = tf.saved_model.load(flight_policy_path)
# Wrap policy to work with non-batched observations at test time.
flight_policy = TestPolicyWrapper(flight_policy)

zero_obs = {k: tf.zeros(v.shape, dtype=v.dtype) for k, v in env.observation_spec().items()}
act = flight_policy(zero_obs)
act.shape
(12,)

Here's a more self-contained version without the rollout_and_render wrapper:

env = flight_imitation(wpg_pattern_path,
                       ref_flight_path,
                       terminal_com_dist=float('inf'))
env = wrappers.SinglePrecisionWrapper(env)
env = wrappers.CanonicalSpecWrapper(env, clip=True)

flight_policy = tf.saved_model.load(flight_policy_path)
# Wrap policy to work with non-batched observations at test time.
flight_policy = TestPolicyWrapper(flight_policy)


print("env.observation_spec():")
print(env.observation_spec())
print()

print("env.action_spec():")
print(env.action_spec())
print()


timestep = env.reset()
print("timestep.observation from env.reset():")
print({k: v.shape for k, v in timestep.observation.items()})
print()

action = flight_policy(timestep.observation)
print("flight_policy(timestep.observation) action output:")
print(action.shape)
print()

timestep = env.step(action)  # throws error
env.observation_spec():
OrderedDict([('walker/accelerometer', Array(shape=(3,), dtype=dtype('float32'), name='walker/accelerometer')), ('walker/actuator_activation', Array(shape=(0,), dtype=dtype('float32'), name='walker/actuator_activation')), ('walker/gyro', Array(shape=(3,), dtype=dtype('float32'), name='walker/gyro')), ('walker/joints_pos', Array(shape=(25,), dtype=dtype('float32'), name='walker/joints_pos')), ('walker/joints_vel', Array(shape=(25,), dtype=dtype('float32'), name='walker/joints_vel')), ('walker/velocimeter', Array(shape=(3,), dtype=dtype('float32'), name='walker/velocimeter')), ('walker/world_zaxis', Array(shape=(3,), dtype=dtype('float32'), name='walker/world_zaxis')), ('walker/ref_displacement', Array(shape=(6, 3), dtype=dtype('float32'), name='walker/ref_displacement')), ('walker/ref_root_quat', Array(shape=(6, 4), dtype=dtype('float32'), name='walker/ref_root_quat'))])

env.action_spec():
BoundedArray(shape=(12,), dtype=dtype('float32'), name='head_abduct\thead_twist\thead\twing_yaw_left\twing_roll_left\twing_pitch_left\twing_yaw_right\twing_roll_right\twing_pitch_right\tabdomen_abduct\tabdomen\tuser_0', minimum=[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.], maximum=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])

timestep.observation from env.reset():
{'walker/accelerometer': (3,), 'walker/actuator_activation': (0,), 'walker/gyro': (3,), 'walker/joints_pos': (25,), 'walker/joints_vel': (25,), 'walker/velocimeter': (3,), 'walker/world_zaxis': (3,), 'walker/ref_displacement': (6, 3), 'walker/ref_root_quat': (6, 4)}

flight_policy(timestep.observation) action output:
(12,)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[22], line 31
     28 print(action.shape)
     29 print()
---> 31 timestep = env.step(action)  # throws error

(The rest of the traceback is identical to the one above, from CanonicalSpecWrapper.step through FruitFly.apply_action.)

ValueError: could not broadcast input array from shape (12,) into shape (11,)

FWIW the walk_imitation env works fine, but the vision_guided_flight env crashes with the same error.

walk_imitation data format

Hi team,

I am trying to do imitation with rodent data and I was thinking of converting our STAC data to a format similar to yours. However, I am not too sure what joint_quat, root2site, root_qpos, and root_qvel are. Would it be possible to also give some guidance on your data preprocessing steps? Thanks for all the help.

This is what I got from reading the fly walking h5 file:
Dataset: trajectories/00000/joint_quat   shape: (90, 102, 4)  dtype: float32
Dataset: trajectories/00000/root2site    shape: (90, 6, 3)    dtype: float32
Dataset: trajectories/00000/root_qpos    shape: (90, 7)       dtype: float32
Dataset: trajectories/00000/root_qvel    shape: (90, 6)       dtype: float32

I also saw the processing of walker data in get_walker_features(physics, mocap_joints, mocap_sites). Is a similar process done to produce the reference trajectories?
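
For reference, the listing above can be reproduced with h5py along these lines (a minimal sketch; the filename is a placeholder, and the comments are educated guesses about the field meanings):

import h5py

# Filename is a placeholder for the fly walking reference-data file.
with h5py.File('walking_trajectories.h5', 'r') as f:
    traj = f['trajectories/00000']
    joint_quat = traj['joint_quat'][:]  # (90, 102, 4): presumably per-frame joint quaternions
    root2site = traj['root2site'][:]    # (90, 6, 3): presumably root-to-tracking-site vectors
    root_qpos = traj['root_qpos'][:]    # (90, 7): root position (3) + orientation quaternion (4)
    root_qvel = traj['root_qvel'][:]    # (90, 6): root linear (3) + angular (3) velocity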
