instadeepai / jumanji
🕹️ A diverse suite of scalable reinforcement learning environments in JAX
Home Page: https://instadeepai.github.io/jumanji
License: Apache License 2.0
Previously the workflow was not triggered.
We should add a py.typed file so that the mypy type checker knows to use the type hints provided by the published jumanji package. See this for more information.
The table in the Examples section of the README renders correctly on GitHub but incorrectly on the documentation site.
Fix: Add a space between the table and the text.
As of now, when calling make with an environment name that omits the version, the current behaviour is to fetch version v0. This will become a problem when we start having more than one version per environment: getting v0 doesn't make sense.
I suggest we throw an error if the version number is missing. This would simplify the code and force users to be explicit about the version they want. It is also better for reproducibility. It is slightly less user-friendly because they would need to look up the listing of the registered environments before calling make.
An alternative solution is to change the described behaviour to fetch the latest version of an environment if the version number is omitted.
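The two options discussed above (raise when the version is missing, or fall back to the latest registered version) can be sketched as follows. This is only an illustration: the ID pattern `<name>-v<version>` and the helper names `parse_env_id`, `resolve_strict` and `resolve_latest` are assumptions, not Jumanji's actual registration code.

```python
import re
from typing import Iterable, Optional, Tuple

# Assumed ID format "<name>-v<version>"; hypothetical, not Jumanji's real parser.
_ENV_ID = re.compile(r"^(?P<name>.+?)(?:-v(?P<version>\d+))?$")


def parse_env_id(env_id: str) -> Tuple[str, Optional[int]]:
    """Split an ID into (name, version); version is None when omitted."""
    match = _ENV_ID.match(env_id)
    if match is None:
        raise ValueError(f"Malformed environment ID: {env_id!r}")
    version = match.group("version")
    return match.group("name"), None if version is None else int(version)


def resolve_strict(env_id: str) -> str:
    """Option 1: refuse IDs that do not state a version."""
    name, version = parse_env_id(env_id)
    if version is None:
        raise ValueError(f"Version missing in {env_id!r}; use e.g. '{name}-v0'.")
    return env_id


def resolve_latest(env_id: str, registered: Iterable[str]) -> str:
    """Option 2: fall back to the highest registered version."""
    name, version = parse_env_id(env_id)
    if version is not None:
        return env_id
    parsed = (parse_env_id(other) for other in registered)
    versions = [v for n, v in parsed if n == name and v is not None]
    if not versions:
        raise ValueError(f"No registered version found for {name!r}.")
    return f"{name}-v{max(versions)}"


registered_ids = ["TSP-v0", "TSP-v1", "Snake-6x6-v0"]
latest_tsp = resolve_latest("TSP", registered_ids)
```

The strict variant favours reproducibility, while the latest-version variant favours convenience; both share the same parsing step.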
The mypy badge appears broken on the online documentation (build hosted by GitHub Pages).
The CI pipeline fails at the linting step; specifically, flake8 fails with the following error:
An unexpected error has occurred: CalledProcessError: command: ('/usr/bin/git', 'fetch', 'origin', '--tags')
return code: 128
expected return code: 0
stdout: (none)
stderr:
fatal: could not read Username for 'https://gitlab.com/': No such device or address
Check the log at /home/runner/.cache/pre-commit/pre-commit.log
v0.1.1
This error occurs because flake8 took down their GitLab repository in favour of their GitHub repository. However, by default, pre-commit uses the GitLab link. Thus, we need to replace https://gitlab.com/PyCQA/flake8 with https://github.com/PyCQA/flake8 in the .pre-commit-config.yaml. For more information, see this.
env.observation_spec is a method and creates a new spec each time it is called. If we make it a property, we can JIT a function that gets these specs.
def __init__(self):
    self.observation_spec = self._make_observation_spec()

def _make_observation_spec(self):
    return spec.Spec(...)
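A closely related way to build the spec only once is a cached property. The sketch below is not the proposal above (which assigns the spec in `__init__`); it swaps in `functools.cached_property` from the standard library, and `ExampleEnv` with its dict "spec" are stand-ins invented for illustration.

```python
from functools import cached_property


class ExampleEnv:
    """Hypothetical environment; a plain dict stands in for a real spec object."""

    def _make_observation_spec(self) -> dict:
        # In a real environment this would build and return a spec.
        return {"shape": (6, 6), "dtype": "float32"}

    @cached_property
    def observation_spec(self) -> dict:
        # Evaluated on first access, then stored on the instance.
        return self._make_observation_spec()


env = ExampleEnv()
first_access = env.observation_spec
second_access = env.observation_spec
```

Both accesses return the same object, so the spec is created a single time per environment instance.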
Make it clear to the community which style guide they should abide by; this could reduce the number of review iterations required to get a pull request in.
Add a reference to the Google style guide in the developer documentation or README.
jumanji.make("TSP") should default to the latest version of TSP, e.g. "TSP-v1". This way, a user who sees the codebase and agrees with the current version of TSP should be able to use TSP and not have to search for what version of TSP we are at.
Currently, the BraxToJumanjiWrapper cannot accept a Brax environment that has been wrapped by Brax's VmapWrapper due to jax.lax.cond not broadcasting in the case of a vmap-ed environment. A current workaround is to not use Brax's VmapWrapper, convert the environment to Jumanji with BraxToJumanjiWrapper and then use Jumanji's VmapWrapper; this will produce the desired outcome.
To summarize,
from jumanji import wrappers as jumanji_wrappers
from brax.envs import create, wrappers as brax_wrappers
brax_env = create("ant")
# This does not work
jumanji_env = jumanji_wrappers.BraxToJumanjiWrapper(brax_wrappers.VmapWrapper(brax_env))
# This works
jumanji_env = jumanji_wrappers.VmapWrapper(jumanji_wrappers.BraxToJumanjiWrapper(brax_env))
The goal of this issue is to make the first solution work as well for consistency.
The jax.lax.cond line in the step function of BraxToJumanjiWrapper should be changed to handle the case where state.done is a vector and not a scalar (i.e., when the Brax environment originates from a VmapWrapper).
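One possible way to handle a vector-valued `done` is to replace the branch with an element-wise select. The sketch below is not the wrapper's actual code: `select_by_done` is a hypothetical helper that applies `jnp.where` to every leaf of two pytrees. Unlike `jax.lax.cond`, it needs both candidate results to be computed beforehand.

```python
import jax
import jax.numpy as jnp


def select_by_done(done, on_done, otherwise):
    """Pick leaves from `on_done` where `done` is set, else from `otherwise`.

    Works for a scalar `done` and for a batched `done` of shape (batch,).
    """
    done = jnp.asarray(done).astype(bool)

    def _select(x, y):
        # Append singleton axes so `done` broadcasts against the leaf.
        extra_dims = max(x.ndim - done.ndim, 0)
        mask = jnp.reshape(done, done.shape + (1,) * extra_dims)
        return jnp.where(mask, x, y)

    return jax.tree_util.tree_map(_select, on_done, otherwise)


# Batched example: the first environment is done, the second is not.
batched_done = jnp.array([1.0, 0.0])
terminal = {"obs": jnp.ones((2, 3))}
running = {"obs": jnp.zeros((2, 3))}
selected = select_by_done(batched_done, terminal, running)
```

The same helper also accepts a scalar `done`, so a single code path could serve both the wrapped and unwrapped cases.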
import jax
from brax.envs import wrappers as brax_wrappers
from brax.envs import create
from jumanji.wrappers import BraxToJumanjiWrapper
brax_env = create("ant")
jumanji_env = BraxToJumanjiWrapper(brax_wrappers.VmapWrapper(brax_env))
state, timestep = jax.jit(jumanji_env.reset)(jax.random.split(jax.random.PRNGKey(0), num=1))
action = jumanji_env.action_spec().generate_value()[None, ...]
state, timestep = jax.jit(jumanji_env.step)(state, action)
With jumanji_env = jumanji_wrappers.BraxToJumanjiWrapper(brax_wrappers.VmapWrapper(brax_env)), calling jumanji_env.reset and jumanji_env.step should work without raising exceptions.
Change copyright year to 2023 instead of 2022.
Connect4 will be removed in a future release (v0.2).
Raise a deprecated warning when using Connect4 as it will be removed soon.
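A minimal sketch of such a warning, using the standard `warnings` module. The class below is a stand-in for illustration, not the real Connect4 environment, and the message wording is an assumption.

```python
import warnings


class Connect4:
    """Stand-in class; the real environment has many more members."""

    def __init__(self) -> None:
        warnings.warn(
            "Connect4 is deprecated and will be removed in a future release (v0.2).",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not at this line
        )


# Record the warning so it can be inspected instead of printed.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Connect4()
```

Python hides `DeprecationWarning` by default outside of `__main__` and test runners, so a more visible category such as `FutureWarning` may be worth considering.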
Bug in the reset function: action_mask = jnp.ones((BOARD_WIDTH,), dtype=jnp.int8) should be changed to action_mask = jnp.ones((BOARD_WIDTH,), dtype=bool).
On the documentation website, there are hyperlinks/shortcuts that do not work properly, although they are working on the readme on GitHub.
Clicking the shortcuts on the menu bar (Installation | Quickstart | ... | Reference Docs) should redirect the user to the corresponding locations on the page. There is a problem with the emojis, which are sometimes part of the link (in GitHub's README, I think) and sometimes not (on the website, I believe). Also, the "contributing guidelines" link does not work on the website because there is no page hosted for it.
It seems that one has to reconcile the way GitHub renders the readme and the way mkdocs builds the doc when it comes to hyperlinks.
Jumanji specs have a generate_value method which essentially returns zeros of the correct pytree/shape. It would be nice to be able to sample values (with an optional mask) from the action spec. This would give us random policies for free for all environments.
Add a sample method to the specs similar to Gym.
We need to decide whether it is redundant to have both sample and generate_value.
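To make the request concrete, here is a rough sketch of what masked sampling could look like for a discrete action space. `sample_discrete` is a hypothetical free function, not an existing Jumanji spec method; it relies on `jax.random.categorical`, giving invalid actions a logit of minus infinity.

```python
import jax
import jax.numpy as jnp


def sample_discrete(key, num_values: int, mask=None):
    """Sample an action index uniformly among the valid entries of `mask`."""
    logits = jnp.zeros((num_values,))
    if mask is not None:
        # Invalid actions get -inf logits, i.e. zero probability.
        logits = jnp.where(mask, logits, -jnp.inf)
    return jax.random.categorical(key, logits)


key = jax.random.PRNGKey(0)
only_third_valid = jnp.array([False, False, True, False])
action = sample_discrete(key, num_values=4, mask=only_third_valid)
```

If every entry of the mask is False, the result is arbitrary, so a real implementation would need to decide how to treat that case.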
param_size is not used so it should be removed. This might make the network builder functions lighter as some may not need the env specs anymore (e.g. for TSP).
Create a JobShop instance generator whose optimal makespan is known. This will be useful to benchmark agents better since the best solution will be known in advance.
The generator will generate a random schedule based on a specified makespan (length of the schedule), number of machines, number of jobs, and max number of operations per job.
Point colab link to notebook in main branch
The environment factory ENV_FACTORY in setup_train.py is no longer used and can be removed.
When working with a dm_env.Environment version of a Jumanji environment using the JumanjiToDMEnvWrapper, one may need to get/set the state and key of the environment, e.g. to allow planning and "restart" the environment to its previous state.
Right now, this is possible by calling wrapped_env._state and wrapped_env._key, which should not be allowed since key and state are private attributes.
A solution would be to properly expose them as properties and to implement setters for these properties (in the common style).
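The property-plus-setter pattern mentioned above can be sketched as follows. `WrapperWithState` is a hypothetical stand-in rather than the real JumanjiToDMEnvWrapper, and the state and key values are placeholders.

```python
from typing import Any


class WrapperWithState:
    """Hypothetical stand-in for a wrapper that holds an env state and a key."""

    def __init__(self, state: Any, key: Any) -> None:
        self._state = state
        self._key = key

    @property
    def state(self) -> Any:
        return self._state

    @state.setter
    def state(self, value: Any) -> None:
        self._state = value

    @property
    def key(self) -> Any:
        return self._key

    @key.setter
    def key(self, value: Any) -> None:
        self._key = value


wrapped_env = WrapperWithState(state={"t": 0}, key=42)
saved_state = wrapped_env.state      # snapshot before planning
wrapped_env.state = {"t": 5}         # simulate stepping ahead
wrapped_env.state = saved_state      # "restart" from the snapshot
```

Setters also give a natural place to add validation later without changing the public attribute names.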
Expose state and key in JumanjiToDMEnvWrapper as properties.
Create an environment (e.g. jumanji.make("Snake-6x6-v0")), wrap it with JumanjiToDMEnvWrapper and then get and set the corresponding state and key attributes without leading underscores.
Jumanji supports Python 3.7, so we should add a CI pipeline in GitHub Actions for running linters, tests, etc. for Python 3.7.
Create an abstract base class Viewer which all environment viewers inherit from. This will help standardise how rendering is done across all environments.
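As a rough illustration of such a base class, the sketch below uses `abc.ABC`. The method names `render` and `close`, and the `TextViewer` subclass, are assumptions chosen for the example and do not describe an agreed interface.

```python
import abc
from typing import Generic, TypeVar

State = TypeVar("State")


class Viewer(abc.ABC, Generic[State]):
    """Hypothetical common interface for environment viewers."""

    @abc.abstractmethod
    def render(self, state: State) -> str:
        """Return a rendering of a single environment state."""

    @abc.abstractmethod
    def close(self) -> None:
        """Release any resources held by the viewer."""


class TextViewer(Viewer[int]):
    """Toy viewer that renders an integer state as text."""

    def render(self, state: int) -> str:
        return f"state={state}"

    def close(self) -> None:
        pass


rendered = TextViewer().render(3)
```

Because the methods are abstract, a viewer that forgets to implement one of them fails at instantiation time rather than at render time.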
Good morning,
Importing jumanji is making my other tests segfault.
Without any test calling or importing code relying on Jumanji, if I just add a simple import jumanji, the tests segfault; if I remove the import, the tests pass.
If I comment out, in jumanji/__init__.py, the import of the binpack sub-module (from jumanji.environments.combinatorial import binpack as _binpack) and the associated env registrations, the tests pass.
I suspect it has to do with something in jumanji\environments\__init__.py, which is loaded when importing binpack. Since there are a lot of imports in the init, it's hard to find the culprit.
I can't disclose the dependencies I am using.
Thanks,
Cyprien
0.1.3
CPU/GPU
Python 3.10.8 - WSL - Ubuntu 20.04 LTS
I have forked the repo and am working on making the registration of the binpack env not require any import from the env.
Instead of:
register(
    id="BinPack-rand20-v0",
    entry_point="jumanji.environments:BinPack",
    kwargs={
        "instance_generator": _binpack.instance_generator.RandomInstanceGenerator(
            max_num_items=20,
            max_num_ems=80,
        ),
        "obs_num_ems": 40,
    },
)
We could have:
register(
    id="BinPack-rand20-v0",
    entry_point="jumanji.environments:BinPack",
    kwargs={
        "instance_generator": {
            "type": "random",
            "max_num_items": 20,
            "max_num_ems": 80,
        },
        "obs_num_ems": 40,
    },
)
or,
register(
    id="BinPack-rand20-v0",
    entry_point="jumanji.environments:BinPack",
    kwargs={
        "instance_generator": "random",
        "max_num_items": 20,
        "max_num_ems": 80,
        "obs_num_ems": 40,
    },
)
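With the dict-based variant, the environment would have to turn the configuration into a generator object when it is constructed. A minimal sketch of that step is shown below; `GENERATORS`, `build_instance_generator` and the dataclass are hypothetical stand-ins, not code from the fork.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class RandomInstanceGenerator:
    """Stand-in for the real generator; it only stores its settings."""

    max_num_items: int
    max_num_ems: int


# Hypothetical lookup table from a short type name to a generator class.
GENERATORS: Dict[str, Callable[..., Any]] = {"random": RandomInstanceGenerator}


def build_instance_generator(config: Dict[str, Any]) -> Any:
    """Instantiate the generator described by `config` at env creation time."""
    params = dict(config)  # copy so the registered kwargs are not mutated
    generator_type = params.pop("type")
    if generator_type not in GENERATORS:
        raise ValueError(f"Unknown instance generator type: {generator_type!r}")
    return GENERATORS[generator_type](**params)


generator = build_instance_generator(
    {"type": "random", "max_num_items": 20, "max_num_ems": 80}
)
```

Since the registration then holds only plain data, registering the environment no longer requires importing the binpack module.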
The extras field of TimeStep can contain environment information useful for decision-making (e.g. Connect4's current player ID) or environment metrics (e.g. BinPack's volume utilisation). There is an inconsistency in what the extras field is used for as it is sometimes meant to be used by the algorithm and sometimes just logged as a metric.
We should move any algorithm-related information from extras to the environment observation (e.g. Connect4's observation could have another field called current_player or something). We should update the documentation/docstrings accordingly to explicitly mention that TimeStep.extras does not contain stuff that is meant to be observed by the agent as those should be in the observation.
TimeStep.extras does not contain any info meant to be observed
One assumption about environment states is that they have a key (JAX random key) attribute to manage stochasticity in the environment step function. This key is then used in wrappers such as JumanjiToDMEnvWrapper.
I am not sure of the solution to go for. I have identified two possible ways: using protocols to make it explicit that states have a key attribute, or using some abstract State class that makes the key attribute mandatory.
What is the best way of forcing the environment's State to have a key attribute?
If I run import jumanji from a console (a PyCharm console in my case), I get the following error from this line:
IPython.core.error.UsageError: Invalid GUI request 'notebook', valid ones are:dict_keys(['none', 'osx', 'tk', 'gtk', 'wx', 'qt', 'qt4', 'qt5', 'glut', 'pyglet', 'gtk3'])
I think it comes from the fact that we are checking the environment to set up the proper matplotlib backend accordingly. However, the current version seems to assume a Jupyter notebook when it is a Python console, thus breaking at import time.
The solution would be to set up the backend optionally, i.e. if an error is encountered, then a default backend is set. Or to figure out how to properly differentiate Jupyter notebooks from something else. In any case, we should not break at import time because of rendering!
v0.1.1
Linux
Since the dataclasses we use are mutable, some side effects may occur when working with environment State and TimeStep.
Use NamedTuple instead.
We could also freeze the dataclasses to be immutable but the NamedTuple option is preferred.
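A small sketch of what the NamedTuple option looks like in practice. The `State` fields are invented for the example; real environment states carry more information.

```python
from typing import NamedTuple


class State(NamedTuple):
    """Hypothetical minimal state used only for illustration."""

    position: int
    step_count: int


state = State(position=0, step_count=0)

# Updates create a new object; the original is left untouched.
next_state = state._replace(step_count=state.step_count + 1)

try:
    state.position = 1  # type: ignore[misc]
    mutated = True
except AttributeError:
    # NamedTuple fields cannot be reassigned in place.
    mutated = False
```

NamedTuples are also handled as pytrees by JAX, so such a state can be passed through jitted functions without extra registration.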
People are more likely to use Jumanji if they can quickly try it out without a lot of admin. We should have an example notebook (we previously had the anakin snake notebook but it was removed because it was outdated).
We would like Jumanji to support Python 3.10.
The Jumanji logo, badges and environment GIF do not appear on the PyPI description page. This is probably due to not exporting the images when uploading to PyPI.
v0.1.1
Sorry for the comment from out of the blue. Jumanji's sophisticated API is great, and its application to problems like TSP is really interesting.
Today, we released Pgx, a collection of JAX-based RL environments dedicated to classic board games like Go. We have implemented over 15 environments, including Backgammon, Shogi, and Go, and confirmed that they are considerably faster than existing C++/Python implementations. We also plan to implement Chess and Contract Bridge in the coming weeks.
We believe Jumanji and Pgx can complement each other as both implement JAX-based RL environments but focus on different domains. We would be happy if you could kindly mention Pgx in the README like other JAX-based RL environments if you like it. For example,
🎲 Pgx provides classic board game environments like Backgammon, Shogi, and Go.
🎲 [Pgx](https://github.com/sotetsuk/pgx) provides classic board game environments like Backgammon, Shogi, and Go.
Thanks!
The CI is broken due to the default Python version.
At multiple places in the code, a workaround is used to check typing of chex dataclasses. This involves doing a conditional import using if TYPE_CHECKING; we would like to avoid this if possible. It is related to this issue.
We would like a solution which avoids doing the conditional import.
Use a built-in dataclass instead of a chex dataclass. Example implementation where mypy doesn't complain:
@dataclasses.dataclass(init=False)
class TimeStep(Generic[Observation]):
    step_type: StepType
    reward: Array
    discount: Array
    observation: Observation
    extras: Optional[Dict]

    def __init__(self, step_type: StepType, reward: Array, discount: Array, observation: Observation,
                 extras: Optional[Dict] = None):
        self.step_type = step_type
        self.reward = reward
        self.discount = discount
        self.observation = observation
        self.extras = extras

    def first(self) -> Array:
        return self.step_type == StepType.FIRST

    def mid(self) -> Array:
        return self.step_type == StepType.MID

    def last(self) -> Array:
        return self.step_type == StepType.LAST
Add a hyperlink to the autogenerated docs in the section describing the environment registry and versioning.
This issue is about cleaning the type hinting in the repository. Currently, when a method (inside a given class) returns an object whose type is the class itself (e.g. in Environment), the return type is given with quotes:
class Environment:
    ...
    def unwrapped(self) -> "Environment":
        ...
According to this thread, this is needed for Python<3.7, but can be removed with future annotations from Python 3.7+.
Since Jumanji supports Python>=3.8, we could get rid of these quotes throughout the repository and use future annotations instead.
The root is quite messy with lots of files (e.g. commitlint.config.js, mkdocs.yml, license_header.txt).
We could move a lot of these files to separate folders and keep only necessary files in the root directory (e.g. README and setup.py).
The instance generator for the cleaner environment currently returns an array containing the initial grid. It should instead return the full environment state to be consistent with other environments.
Return an instance of cleaner.State in the instance generator.
jax.random.choice seems to be slow, especially when sampling without replacement. Sampling with replacement seems to be much faster, but even that is probably slower than jax.random.randint (to be verified).
Find a way to use jax.random.choice(replace=False) as little as possible to improve environment speed.
As a first study towards this, it turns out that jax.random.choice(..., replace=True) is faster than jax.random.categorical for sampling with replacement. jax.random.choice(..., replace=False) appears much slower than the other two. When sampling without replacement is needed, we still have to study what the best approach is.
Source: notebook
Alternatives to jax.random.choice(..., replace=False) that could be considered and assessed include:
p1[i] and p2[i]
It may be that jax.random.choice(..., replace=False) ends up being the most optimised version. In any case, the solution may depend on how many samples we need to sample without replacement (e.g. 2 in the case of Snake).
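For the two-sample case mentioned above, one candidate worth benchmarking is to draw two integers with `jax.random.randint` and shift the second one past the first. This is only a sketch of that idea, not a measured result, and `sample_two_distinct` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def sample_two_distinct(key, n: int):
    """Draw two distinct indices in [0, n) uniformly, using only randint."""
    first_key, second_key = jax.random.split(key)
    first = jax.random.randint(first_key, (), 0, n)
    # Draw from the n - 1 remaining slots, then skip over `first`.
    second = jax.random.randint(second_key, (), 0, n - 1)
    second = jnp.where(second >= first, second + 1, second)
    return first, second


pairs = [
    sample_two_distinct(jax.random.PRNGKey(seed), n=5) for seed in range(20)
]
```

Whether this is actually faster than `jax.random.choice(..., replace=False)` would still have to be verified, and it does not generalise directly to many samples.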
We need to take this into account for random policies. It is likely that the random action selection influences the environment speed by a lot, hence biasing speed benchmarks.
Fix typo in line 50, that says snale instead of snake.
Jumanji supports Python 3.7; however, Protocol, as imported from typing, isn't available in Python 3.7.
v0.1.3
Import Protocol from typing_extensions instead of typing.
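A common variant of this fix is a guarded import that prefers the standard library and falls back to the backport. The sketch below assumes the typing_extensions package is installed on Python 3.7; `SupportsRender`, `describe` and `Board` are names made up for the example.

```python
try:
    from typing import Protocol  # available from Python 3.8 onwards
except ImportError:
    from typing_extensions import Protocol  # backport needed on Python 3.7


class SupportsRender(Protocol):
    """Structural type: anything with a matching `render` method qualifies."""

    def render(self) -> str:
        ...


def describe(obj: SupportsRender) -> str:
    return f"rendered: {obj.render()}"


class Board:
    def render(self) -> str:
        return "empty board"


description = describe(Board())
```

Importing unconditionally from typing_extensions, as suggested above, is simpler and behaves the same on newer Python versions.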
When using AutoResetWrapper, the sequence of timestep.step_type values returned during a rollout shows 0 where the environment has been auto-reset instead of 2. Recall that 0 = FIRST, 1 = MID and 2 = LAST.
For the auto-reset bug: [1, 1, 1, 2, 0, 1, 1, 1] has to be converted to [1, 1, 1, X, 1, 1, 1] with X being 0 or 2. Right now it is 0 but it should be 2 to warn the user that the episode got terminated. This would be important e.g. if using while not timestep.last().
timestep = timestep.replace(  # type: ignore
    observation=reset_timestep.observation
)
We should display the test coverage in a badge on the README. This is important to show since a well-tested API is a goal.
The current CONTRIBUTING.md is outdated and makes references to the old jumanji API (from before v0.1)
Update the document accordingly.
There is a potential for speedup in the BinPack environment's step method. If the computation is quite sparse, using jax.lax.map instead of jax.vmap may speed the environment up when lots of EMSs are not alive.
BinPack's step method.
If you go to https://instadeepai.github.io/jumanji/#contributing and click on contributing guidelines, the server will return a 404.
We should state explicitly that we follow the conventional commits specification in the CONTRIBUTING.md.
In jumanji.wrappers the conversion of jumanji.specs to dm_env.specs does not accept general PyTree nodes. As mentioned in google-deepmind/dm_env#10, this should simply be compatible.
Replace the else statement within jumanji.wrappers.jumanji_specs_to_dm_env_specs on line 465 from this:
def jumanji_specs_to_dm_env_specs(
    spec: Spec,
) -> Union[dm_env.specs.DiscreteArray, dm_env.specs.BoundedArray, dm_env.specs.Array]:
    if isinstance(spec, DiscreteArray):
        ...
    elif ...:
        ...
    else:
        raise ValueError(
            f"spec {spec} of type {type(spec)} is not available in a deepmind environment. "
            "Please override the observation_spec or action_spec method to output spec of type "
            "`dm_env.specs.Array`."
        )
to something like this:
def jumanji_specs_to_dm_env_specs(
    spec: Spec,
) -> Union[dm_env.specs.DiscreteArray, dm_env.specs.BoundedArray, dm_env.specs.Array]:
    if isinstance(spec, DiscreteArray):
        ...
    elif ...:
        ...
    else:
        try:
            # Recursively call this function for nested Specs
            return jax.tree_util.tree_map(jumanji_specs_to_dm_env_specs, spec)
        except ...:
            raise ValueError(...)
Add to the README environment table the environment speed (in steps/s).