From e81b2976339c50c4fb1f298da01484128f8684fa Mon Sep 17 00:00:00 2001 From: Ian Osband Date: Tue, 21 Jun 2022 07:24:48 -0700 Subject: [PATCH 1/6] Renaming observation. PiperOrigin-RevId: 456253157 Change-Id: Ic71f175da32c83fd949acf0f385c4e045a0ea746 --- bsuite/environments/cartpole.py | 2 +- bsuite/environments/catch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bsuite/environments/cartpole.py b/bsuite/environments/cartpole.py index ea6133db..1c951b31 100644 --- a/bsuite/environments/cartpole.py +++ b/bsuite/environments/cartpole.py @@ -162,7 +162,7 @@ def action_spec(self): return specs.DiscreteArray(dtype=np.int, num_values=3, name='action') def observation_spec(self): - return specs.Array(shape=(1, 6), dtype=np.float32, name='state') + return specs.Array(shape=(1, 6), dtype=np.float32, name='observation') @property def observation(self) -> np.ndarray: diff --git a/bsuite/environments/catch.py b/bsuite/environments/catch.py index 5f88f4f0..5d51d5c8 100644 --- a/bsuite/environments/catch.py +++ b/bsuite/environments/catch.py @@ -99,7 +99,7 @@ def _step(self, action: int) -> dm_env.TimeStep: def observation_spec(self) -> specs.BoundedArray: """Returns the observation spec.""" return specs.BoundedArray(shape=self._board.shape, dtype=self._board.dtype, - name="board", minimum=0, maximum=1) + name="observation", minimum=0, maximum=1) def action_spec(self) -> specs.DiscreteArray: """Returns the action spec.""" From 2873b9d54995cc36927639485583756fe7173b51 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Fri, 9 Sep 2022 17:15:25 -0700 Subject: [PATCH 2/6] Make this code compatible with Python 3.10. PiperOrigin-RevId: 473376926 Change-Id: Iea794a3e13d274a5694b9762c405cbae3db7cd82 --- bsuite/logging/logging_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bsuite/logging/logging_utils.py b/bsuite/logging/logging_utils.py index 9c801d87..652e5c30 100644 --- a/bsuite/logging/logging_utils.py +++ b/bsuite/logging/logging_utils.py @@ -15,7 +15,7 @@ # ============================================================================ """Read functionality for local csv-based experiments.""" -import collections +from collections import abc import copy from typing import Any, Callable, List, Mapping, Sequence, Tuple, Union @@ -72,7 +72,7 @@ def load_multiple_runs( # Convert any inputs to dictionary format. if isinstance(path_collection, six.string_types): path_collection = {path_collection: path_collection} - if not isinstance(path_collection, collections.Mapping): + if not isinstance(path_collection, abc.Mapping): path_collection = {path: path for path in path_collection} # Loop through multiple bsuite runs, and apply single_load_fn to each. From b26b429d64bebd9a507201e2328466e037b40727 Mon Sep 17 00:00:00 2001 From: DeepMind Date: Wed, 5 Oct 2022 03:46:29 -0700 Subject: [PATCH 3/6] Fix threshold in deep sea analysis so that the results match what is claimed in the colab. PiperOrigin-RevId: 479005894 Change-Id: I7418926eae648651deb6971863957cdb682caf14 --- bsuite/experiments/deep_sea/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bsuite/experiments/deep_sea/analysis.py b/bsuite/experiments/deep_sea/analysis.py index fc336dc5..3616cf8d 100644 --- a/bsuite/experiments/deep_sea/analysis.py +++ b/bsuite/experiments/deep_sea/analysis.py @@ -37,7 +37,7 @@ def _check_data(df: pd.DataFrame) -> None: def find_solution(df_in: pd.DataFrame, sweep_vars: Optional[Sequence[str]] = None, merge: bool = True, - thresh: float = 0.8, + thresh: float = 0.9, num_episodes: int = NUM_EPISODES) -> pd.DataFrame: """Find first episode that gets below thresh regret by sweep_vars.""" # Check data has the necessary columns for deep sea From 10defdaec12537e34f798a220a9af624ec124e41 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 Dec 2022 02:37:44 -0800 Subject: [PATCH 4/6] [NumPy] Remove references to deprecated NumPy type aliases. This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str). NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy. PiperOrigin-RevId: 496609197 Change-Id: I3520b3fe132a37119a0a1949a9a92e3426499607 --- bsuite/baselines/utils/sequence_test.py | 4 ++-- bsuite/environments/cartpole.py | 2 +- bsuite/environments/catch.py | 2 +- bsuite/experiments/cartpole_swingup/cartpole_swingup.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bsuite/baselines/utils/sequence_test.py b/bsuite/baselines/utils/sequence_test.py index 985a487b..a4e25b38 100644 --- a/bsuite/baselines/utils/sequence_test.py +++ b/bsuite/baselines/utils/sequence_test.py @@ -31,8 +31,8 @@ def test_buffer(self): max_sequence_length = 10 obs_shape = (3, 3) buffer = sequence.Buffer( - obs_spec=specs.Array(obs_shape, dtype=np.float), - action_spec=specs.Array((), dtype=np.int), + obs_spec=specs.Array(obs_shape, dtype=float), + action_spec=specs.Array((), dtype=int), max_sequence_length=max_sequence_length) dummy_step = dm_env.transition(observation=np.zeros(obs_shape), reward=0.) diff --git a/bsuite/environments/cartpole.py b/bsuite/environments/cartpole.py index 1c951b31..f7be258b 100644 --- a/bsuite/environments/cartpole.py +++ b/bsuite/environments/cartpole.py @@ -159,7 +159,7 @@ def _reset(self) -> dm_env.TimeStep: raise NotImplementedError('This environment implements its own auto-reset.') def action_spec(self): - return specs.DiscreteArray(dtype=np.int, num_values=3, name='action') + return specs.DiscreteArray(dtype=int, num_values=3, name='action') def observation_spec(self): return specs.Array(shape=(1, 6), dtype=np.float32, name='observation') diff --git a/bsuite/environments/catch.py b/bsuite/environments/catch.py index 5d51d5c8..c72d8d24 100644 --- a/bsuite/environments/catch.py +++ b/bsuite/environments/catch.py @@ -104,7 +104,7 @@ def observation_spec(self) -> specs.BoundedArray: def action_spec(self) -> specs.DiscreteArray: """Returns the action spec.""" return specs.DiscreteArray( - dtype=np.int, num_values=len(_ACTIONS), name="action") + dtype=int, num_values=len(_ACTIONS), name="action") def _observation(self) -> np.ndarray: self._board.fill(0.) diff --git a/bsuite/experiments/cartpole_swingup/cartpole_swingup.py b/bsuite/experiments/cartpole_swingup/cartpole_swingup.py index 13c3b9e9..31d690d5 100644 --- a/bsuite/experiments/cartpole_swingup/cartpole_swingup.py +++ b/bsuite/experiments/cartpole_swingup/cartpole_swingup.py @@ -129,7 +129,7 @@ def _reset(self) -> dm_env.TimeStep: raise NotImplementedError('This environment implements its own auto-reset.') def action_spec(self): - return specs.DiscreteArray(dtype=np.int, num_values=3, name='action') + return specs.DiscreteArray(dtype=int, num_values=3, name='action') def observation_spec(self): return specs.Array(shape=(1, 8), dtype=np.float32, name='state') From 2730a359b607594980722510e1501c0570f9cc21 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Feb 2023 00:57:42 -0800 Subject: [PATCH 5/6] Fix or ignore some pytype errors related to jnp.ndarray == jax.Array. PiperOrigin-RevId: 511422304 Change-Id: I8945990f1c970037d16d3456be0063acb91266ee --- bsuite/baselines/jax/actor_critic/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bsuite/baselines/jax/actor_critic/agent.py b/bsuite/baselines/jax/actor_critic/agent.py index 7962b849..aded414f 100644 --- a/bsuite/baselines/jax/actor_critic/agent.py +++ b/bsuite/baselines/jax/actor_critic/agent.py @@ -56,7 +56,7 @@ def __init__( # Define loss function. def loss(trajectory: sequence.Trajectory) -> jnp.ndarray: """"Actor-critic loss.""" - logits, values = network(trajectory.observations) + logits, values = network(trajectory.observations) # pytype: disable=wrong-arg-types # jax-ndarray td_errors = rlax.td_lambda( v_tm1=values[:-1], r_t=trajectory.rewards, From 73de2cbdadf3d28d6a928760530fa06b31e6b1de Mon Sep 17 00:00:00 2001 From: jjshoots Date: Thu, 30 Mar 2023 18:18:26 +0100 Subject: [PATCH 6/6] fix get attr --- bsuite/utils/gym_wrapper.py | 5 +++-- bsuite/utils/wrappers.py | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/bsuite/utils/gym_wrapper.py b/bsuite/utils/gym_wrapper.py index 84dde88b..a0cb67ac 100644 --- a/bsuite/utils/gym_wrapper.py +++ b/bsuite/utils/gym_wrapper.py @@ -97,8 +97,9 @@ def reward_range(self) -> Tuple[float, float]: def __getattr__(self, attr): """Delegate attribute access to underlying environment.""" - return getattr(self._env, attr) - + if "_env" in self.__dict__: + return getattr(self._env, attr) + return super().__getattribute__(attr) def space2spec(space: gym.Space, name: Optional[str] = None): """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs. diff --git a/bsuite/utils/wrappers.py b/bsuite/utils/wrappers.py index 7e064531..d25d3027 100644 --- a/bsuite/utils/wrappers.py +++ b/bsuite/utils/wrappers.py @@ -134,8 +134,9 @@ def raw_env(self): def __getattr__(self, attr): """Delegate attribute access to underlying environment.""" - return getattr(self._env, attr) - + if "_env" in self.__dict__: + return getattr(self._env, attr) + return super().__getattribute__(attr) def _logarithmic_logging(episode: int, ratios: Optional[Sequence[float]] = None) -> bool: @@ -173,8 +174,9 @@ def step(self, action): def __getattr__(self, attr): """Delegate attribute access to underlying environment.""" - return getattr(self._env, attr) - + if "_env" in self.__dict__: + return getattr(self._env, attr) + return super().__getattribute__(attr) def _small_state_to_image(shape: Sequence[int], observation: np.ndarray) -> np.ndarray: @@ -307,8 +309,9 @@ def bsuite_info(self) -> Dict[str, Any]: def __getattr__(self, attr): """Delegate attribute access to underlying environment.""" - return getattr(self._env, attr) - + if "_env" in self.__dict__: + return getattr(self._env, attr) + return super().__getattribute__(attr) class RewardScale(environments.Environment): """Reward Scale environment wrapper.""" @@ -370,4 +373,6 @@ def bsuite_info(self) -> Dict[str, Any]: def __getattr__(self, attr): """Delegate attribute access to underlying environment.""" - return getattr(self._env, attr) + if "_env" in self.__dict__: + return getattr(self._env, attr) + return super().__getattribute__(attr)