From 76d55c91de2388f5d57bd03054fdbcf3fcd528c9 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Thu, 22 Jan 2026 15:40:44 -0800 Subject: [PATCH 1/8] initial setup --- .../isaaclab_experimental/envs/__init__.py | 5 +- .../envs/manager_based_rl_env_warp.py | 64 ++++++ .../envs/mdp/__init__.py | 16 ++ .../isaaclab_experimental/envs/mdp/rewards.py | 86 +++++++ .../managers/__init__.py | 17 ++ .../managers/manager_term_cfg.py | 40 ++++ .../managers/scene_entity_cfg.py | 47 ++++ .../isaaclab_experimental/utils/warp/utils.py | 9 + .../manager_based/__init__.py | 10 + .../manager_based/classic/__init__.py | 6 + .../classic/cartpole/__init__.py | 29 +++ .../classic/cartpole/cartpole_env_cfg.py | 209 ++++++++++++++++++ .../classic/cartpole/mdp/__init__.py | 10 + .../classic/cartpole/mdp/rewards.py | 45 ++++ 14 files changed, 592 insertions(+), 1 deletion(-) create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/manager_term_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py index 7c2f6ace83d..36702cbaa43 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py @@ -42,4 +42,7 @@ .. _`Task Design Workflows`: https://isaac-sim.github.io/IsaacLab/source/features/task_workflows.html """ -from .direct_rl_env_warp import DirectRLEnvWarp +from isaaclab.envs import * # noqa: F401,F403 + +from .direct_rl_env_warp import DirectRLEnvWarp # noqa: F401 +from .manager_based_rl_env_warp import ManagerBasedRLEnvWarp # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py new file mode 100644 index 00000000000..1775439440b --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental manager-based RL env entry point. 
+ +This is intentionally a minimal shim to bootstrap an experimental entry point +without copying task code. The initial behavior matches the existing +`isaaclab.envs.ManagerBasedRLEnv` exactly. + +Future work will incrementally replace internals with Warp-first, graph-friendly +pipelines while keeping the manager-based task authoring model. +""" + +from __future__ import annotations + +from isaaclab_experimental.managers import RewardManager + +from isaaclab.envs import ManagerBasedEnv, ManagerBasedRLEnv +from isaaclab.envs.manager_based_rl_env_cfg import ManagerBasedRLEnvCfg +from isaaclab.managers import CommandManager, CurriculumManager, TerminationManager + + +class ManagerBasedRLEnvWarp(ManagerBasedRLEnv): + """Experimental drop-in replacement for `ManagerBasedRLEnv`. + + Notes: + - No behavior changes are introduced yet. This class exists to provide a new + Gym entry point (`isaaclab_experimental.envs:ManagerBasedRLEnvWarp`) that + we can evolve independently. + """ + + cfg: ManagerBasedRLEnvCfg + + def load_managers(self): + """Load managers but use experimental `RewardManager`. + + This keeps behavior identical to `isaaclab.envs.ManagerBasedRLEnv` while allowing the reward + pipeline to diverge in `isaaclab_experimental.managers.reward_manager`. + """ + # -- command manager (order matters: observations may depend on commands/actions) + self.command_manager = CommandManager(self.cfg.commands, self) + print("[INFO] Command Manager: ", self.command_manager) + + # call the parent class to load the managers for observations/actions/events/recorders. + ManagerBasedEnv.load_managers(self) + + # -- termination manager + self.termination_manager = TerminationManager(self.cfg.terminations, self) + print("[INFO] Termination Manager: ", self.termination_manager) + # -- reward manager (experimental fork) + self.reward_manager = RewardManager(self.cfg.rewards, self) + print("[INFO] Reward Manager: ", self.reward_manager) + # -- curriculum manager + self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self) + print("[INFO] Curriculum Manager: ", self.curriculum_manager) + + # setup the action and observation spaces for Gym + self._configure_gym_env_spaces() + + # perform events at the start of the simulation + if "startup" in self.event_manager.available_modes: + self.event_manager.apply(mode="startup") diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py new file mode 100644 index 00000000000..46d8d4c7015 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental MDP terms. + +This package forwards all stable MDP terms from :mod:`isaaclab.envs.mdp`, but overrides reward +functions with Warp-first implementations from :mod:`isaaclab_experimental.envs.mdp.rewards`. +""" + +# Forward stable MDP terms (actions/observations/terminations/etc.) +from isaaclab.envs.mdp import * # noqa: F401, F403 + +# Override reward terms with experimental implementations. 
+from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py new file mode 100644 index 00000000000..ba0086eda71 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Common functions that can be used to enable reward functions (experimental). + +This module is intentionally minimal: it only contains reward terms that are currently +used by the experimental manager-based Cartpole task. + +All functions in this file follow the Warp-compatible reward signature expected by +`isaaclab_experimental.managers.RewardManager`: + +- ``func(env, out, **params) -> None`` + +where ``out`` is a pre-allocated Warp array of shape ``(num_envs,)`` with ``float32`` dtype. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp +from isaaclab_experimental.managers import SceneEntityCfg + +from isaaclab.assets import Articulation + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +""" +General. +""" + + +@wp.kernel +def _is_alive_kernel(terminated: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32)): + i = wp.tid() + out[i] = wp.where(terminated[i], 0.0, 1.0) + + +def is_alive(env: ManagerBasedRLEnv, out: wp.array(dtype=wp.float32)) -> None: + """Reward for being alive. Writes into ``out`` (shape: (num_envs,)).""" + terminated_wp = wp.from_torch(env.termination_manager.terminated, dtype=wp.bool) + wp.launch(kernel=_is_alive_kernel, dim=env.num_envs, inputs=[terminated_wp, out], device=env.device) + + +@wp.kernel +def _is_terminated_kernel(terminated: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32)): + i = wp.tid() + out[i] = wp.where(terminated[i], 1.0, 0.0) + + +def is_terminated(env: ManagerBasedRLEnv, out) -> None: + """Penalize terminated episodes. Writes into ``out``.""" + terminated_wp = wp.from_torch(env.termination_manager.terminated, dtype=wp.bool) + wp.launch(kernel=_is_terminated_kernel, dim=env.num_envs, inputs=[terminated_wp, out], device=env.device) + + +""" +Joint penalties. +""" + + +@wp.kernel +def _sum_abs_masked_kernel( + x: wp.array(dtype=wp.float32, ndim=2), joint_mask: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32) +): + i = wp.tid() + s = float(0.0) + for j in range(x.shape[1]): + if joint_mask[j]: + s += wp.abs(x[i, j]) + out[i] = s + + +def joint_vel_l1(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg) -> None: + """Penalize joint velocities on the articulation using an L1-kernel. Writes into ``out``.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_abs_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_vel, asset_cfg.joint_mask, out], + device=env.device, + ) diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py new file mode 100644 index 00000000000..1951f9cc5ae --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental manager implementations. + +This package is intended for experimental forks of manager implementations while +keeping stable task configs and the stable `isaaclab.managers` package intact. +""" + +from isaaclab.managers import * # noqa: F401,F403 + +# Override the stable implementation with the experimental fork. +from .manager_term_cfg import RewardTermCfg # noqa: F401 +from .reward_manager import RewardManager # noqa: F401 +from .scene_entity_cfg import SceneEntityCfg # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/manager_term_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_term_cfg.py new file mode 100644 index 00000000000..cfd6423c35f --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_term_cfg.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Configuration terms for different managers. + +This module is a passthrough to `isaaclab.managers.manager_term_cfg` except for +`RewardTermCfg`, which is overridden for the Warp-based reward manager. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import MISSING + +from isaaclab.managers.manager_term_cfg import ManagerTermBaseCfg as _ManagerTermBaseCfg +from isaaclab.managers.manager_term_cfg import * # noqa: F401,F403 +from isaaclab.utils import configclass + + +@configclass +class RewardTermCfg(_ManagerTermBaseCfg): + """Configuration for a reward term. + + The function is expected to write the (unweighted) reward values into a + pre-allocated Warp buffer provided by the manager. + + Expected signature: + + - ``func(env, out, **params) -> None`` + + where ``out`` is a Warp array of shape ``(num_envs,)`` with float32 dtype. + """ + + func: Callable[..., None] = MISSING + """The function to be called to fill the pre-allocated reward buffer.""" + + weight: float = MISSING + """The weight of the reward term.""" diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py new file mode 100644 index 00000000000..612e2695aef --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental fork of :class:`isaaclab.managers.SceneEntityCfg`. + +This adds Warp-only cached selections (e.g. a joint mask) while keeping compatibility +with the stable manager stack (which type-checks against the stable SceneEntityCfg). +""" + +from __future__ import annotations + +import warp as wp + +from isaaclab.assets import Articulation +from isaaclab.managers.scene_entity_cfg import SceneEntityCfg as _SceneEntityCfg +from isaaclab.scene import InteractiveScene + + +class SceneEntityCfg(_SceneEntityCfg): + """Scene entity configuration with an optional Warp joint mask. + + Notes: + - `joint_mask` is intended for Warp kernels only. 
+ """ + + joint_mask: wp.array | None = None + + def resolve(self, scene: InteractiveScene): + # run the stable resolution first (fills joint_ids/body_ids from names/regex) + super().resolve(scene) + + # Build a Warp joint mask for articulations. + entity = scene[self.name] + if not isinstance(entity, Articulation): + self.joint_mask = None + return + + # Pre-allocate a full-length mask (all True for default selection). + if self.joint_ids == slice(None): + mask_list = [True] * entity.num_joints + else: + mask_list = [False] * entity.num_joints + for idx in self.joint_ids: + mask_list[idx] = True + self.joint_mask = wp.array(mask_list, dtype=wp.bool, device=scene.device) diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py index e9c0d51632b..272991f6ec0 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py @@ -7,12 +7,21 @@ from typing import TYPE_CHECKING +import warp as wp + from isaaclab.managers.scene_entity_cfg import SceneEntityCfg if TYPE_CHECKING: from isaaclab.envs import ManagerBasedEnv +@wp.func +def wrap_to_pi(angle: float) -> float: + """Wrap input angle (in radians) to the range [-pi, pi).""" + two_pi = 2.0 * wp.pi + return angle - two_pi * wp.floor((angle + wp.pi) / two_pi) + + def resolve_asset_cfg(cfg: dict, env: ManagerBasedEnv) -> SceneEntityCfg: asset_cfg = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py new file mode 100644 index 00000000000..7f23883e633 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental registrations for manager-based tasks. + +We intentionally only register new Gym IDs pointing at experimental entry points. +Task definitions (configs/mdp) remain in `isaaclab_tasks` to avoid duplication. +""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py new file mode 100644 index 00000000000..4781f141af4 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Classic experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py new file mode 100644 index 00000000000..7bdf0a4a289 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Cartpole balancing environment (experimental manager-based entry point). +""" + +import gymnasium as gym + +gym.register( + id="Isaac-Cartpole-Managed-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + # Use experimental Cartpole cfg (allows isolated modifications). + "env_cfg_entry_point": ( + "isaaclab_tasks_experimental.manager_based.classic.cartpole.cartpole_env_cfg:CartpoleEnvCfg" + ), + # Point agent configs to the existing task package. + "rl_games_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": ( + "isaaclab_tasks.manager_based.classic.cartpole.agents.rsl_rl_ppo_cfg:CartpolePPORunnerCfg" + ), + "skrl_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py new file mode 100644 index 00000000000..c2560ea254f --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py @@ -0,0 +1,209 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import math + +import isaaclab_tasks_experimental.manager_based.classic.cartpole.mdp as mdp +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg + +import isaaclab.sim as sim_utils +from isaaclab.assets import ArticulationCfg, AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.managers import ObservationTermCfg as ObsTerm +from isaaclab.managers import TerminationTermCfg as DoneTerm +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.sim import SimulationCfg +from isaaclab.sim._impl.newton_manager_cfg import NewtonCfg +from isaaclab.sim._impl.solvers_cfg import MJWarpSolverCfg +from isaaclab.utils import configclass + +## +# Pre-defined configs +## +from isaaclab_assets.robots.cartpole import CARTPOLE_CFG # isort:skip + + +## +# Scene definition +## + + +@configclass +class CartpoleSceneCfg(InteractiveSceneCfg): + """Configuration for a cart-pole scene.""" + + # ground plane + # ground = AssetBaseCfg( + # prim_path="/World/ground", + # spawn=sim_utils.GroundPlaneCfg(size=(100.0, 100.0)), + # ) + + # cartpole + robot: ArticulationCfg = CARTPOLE_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + + # lights + dome_light = AssetBaseCfg( + prim_path="/World/DomeLight", + spawn=sim_utils.DomeLightCfg(color=(0.9, 0.9, 0.9), intensity=500.0), + ) + + +## +# MDP settings +## + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + joint_effort = mdp.JointEffortActionCfg(asset_name="robot", joint_names=["slider_to_cart"], scale=100.0) + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for policy group.""" + + # observation 
terms (order preserved) + joint_pos_rel = ObsTerm(func=mdp.joint_pos_rel) + joint_vel_rel = ObsTerm(func=mdp.joint_vel_rel) + + def __post_init__(self) -> None: + self.enable_corruption = False + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + # reset + reset_cart_position = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"]), + "position_range": (-1.0, 1.0), + "velocity_range": (-0.5, 0.5), + }, + ) + + reset_pole_position = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=["cart_to_pole"]), + "position_range": (-0.25 * math.pi, 0.25 * math.pi), + "velocity_range": (-0.25 * math.pi, 0.25 * math.pi), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # (1) Constant running reward + alive = RewTerm(func=mdp.is_alive, weight=1.0) + # (2) Failure penalty + terminating = RewTerm(func=mdp.is_terminated, weight=-2.0) + # (3) Primary task: keep pole upright + pole_pos = RewTerm( + func=mdp.joint_pos_target_l2, + weight=-1.0, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["cart_to_pole"]), "target": 0.0}, + ) + # (4) Shaping tasks: lower cart velocity + cart_vel = RewTerm( + func=mdp.joint_vel_l1, + weight=-0.01, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"])}, + ) + # (5) Shaping tasks: lower pole angular velocity + pole_vel = RewTerm( + func=mdp.joint_vel_l1, + weight=-0.005, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["cart_to_pole"])}, + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + # (1) Time out + time_out = DoneTerm(func=mdp.time_out, time_out=True) + # (2) Cart out of bounds + cart_out_of_bounds = DoneTerm( + func=mdp.joint_pos_out_of_manual_limit, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"]), "bounds": (-3.0, 3.0)}, + ) + + +## +# Environment configuration +## + + +@configclass +class CartpoleEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the cartpole environment.""" + + sim: SimulationCfg = SimulationCfg( + newton_cfg=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + nconmax=5, + ), + ) + ) + + # Scene settings + scene: CartpoleSceneCfg = CartpoleSceneCfg(num_envs=4096, env_spacing=4.0, clone_in_fabric=True) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + events: EventCfg = EventCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + # Simulation settings + sim: SimulationCfg = SimulationCfg( + newton_cfg=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=5, + nconmax=3, + ls_iterations=10, + cone="pyramidal", + impratio=1, + ls_parallel=True, + integrator="implicit", + ), + num_substeps=1, + debug_mode=False, + use_cuda_graph=True, + ) + ) + + # Post initialization + def __post_init__(self) -> None: + """Post initialization.""" + # general settings + self.decimation = 2 + self.episode_length_s = 5 + # viewer settings + self.viewer.eye = (8.0, 0.0, 5.0) + # simulation settings + self.sim.dt = 1 / 120 + self.sim.render_interval = self.decimation diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py 
b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py new file mode 100644 index 00000000000..73b1cf4fb2c --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This sub-module contains the functions that are specific to the cartpole environments.""" + +from isaaclab_experimental.envs.mdp import * # noqa: F401, F403 + +from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py new file mode 100644 index 00000000000..c5043964f17 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.utils.warp.utils import wrap_to_pi + +from isaaclab.assets import Articulation + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +@wp.kernel +def _joint_pos_target_l2_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + joint_mask: wp.array(dtype=wp.bool), + out: wp.array(dtype=wp.float32), + target: float, +): + i = wp.tid() + s = float(0.0) + for j in range(joint_pos.shape[1]): + if joint_mask[j]: + a = wrap_to_pi(joint_pos[i, j]) + d = a - target + s += d * d + out[i] = s + + +def joint_pos_target_l2(env: ManagerBasedRLEnv, out, target: float, asset_cfg: SceneEntityCfg) -> None: + """Penalize joint position deviation from a target value. 
Writes into ``out``.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_joint_pos_target_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset_cfg.joint_mask, out, target], + device=env.device, + ) From fe72f549b54c55d1118137df568646ae26835a16 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Fri, 23 Jan 2026 00:54:21 -0800 Subject: [PATCH 2/8] Added timer to manager based workflow --- .../isaaclab/envs/manager_based_rl_env.py | 93 ++++++++++--------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py index 07aa4e42f0b..0e26efebb58 100644 --- a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py +++ b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py @@ -15,6 +15,7 @@ from isaaclab.managers import CommandManager, CurriculumManager, RewardManager, TerminationManager from isaaclab.ui.widgets import ManagerLiveVisualizer +from isaaclab.utils.timer import Timer from .common import VecEnvStepReturn from .manager_based_env import ManagerBasedEnv @@ -149,6 +150,7 @@ def setup_manager_visualizers(self): Operations - MDP """ + @Timer(name="env_step", msg="Step took:", enable=True, format="us") def step(self, action: torch.Tensor) -> VecEnvStepReturn: """Execute one time-step of the environment's dynamics and reset terminated environments. @@ -181,11 +183,13 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: for _ in range(self.cfg.decimation): self._sim_step_counter += 1 # set actions into buffers - self.action_manager.apply_action() - # set actions into simulator - self.scene.write_data_to_sim() + with Timer(name="apply_action", msg="Action processing step took:", enable=True, format="us"): + self.action_manager.apply_action() + # set actions into simulator + self.scene.write_data_to_sim() # simulate - self.sim.step(render=False) + with Timer(name="simulate", msg="Newton simulation step took:", enable=True, format="us"): + self.sim.step(render=False) self.recorder_manager.record_post_physics_decimation_step() # render between steps only if the GUI or an RTX sensor needs it # note: we assume the render interval to be the shortest accepted rendering interval. 
@@ -195,46 +199,47 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: # update buffers at sim dt self.scene.update(dt=self.physics_dt) - # post-step: - # -- update env counters (used for curriculum generation) - self.episode_length_buf += 1 # step in current episode (per env) - self.common_step_counter += 1 # total step (common for all envs) - # -- check terminations - self.reset_buf = self.termination_manager.compute() - self.reset_terminated = self.termination_manager.terminated - self.reset_time_outs = self.termination_manager.time_outs - # -- reward computation - self.reward_buf = self.reward_manager.compute(dt=self.step_dt) - - if len(self.recorder_manager.active_terms) > 0: - # update observations for recording if needed - self.obs_buf = self.observation_manager.compute() - self.recorder_manager.record_post_step() - - # -- reset envs that terminated/timed-out and log the episode information - reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) - if len(reset_env_ids) > 0: - # trigger recorder terms for pre-reset calls - self.recorder_manager.record_pre_reset(reset_env_ids) - - self._reset_idx(reset_env_ids) - - # if sensors are added to the scene, make sure we render to reflect changes in reset - if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: - for _ in range(self.cfg.num_rerenders_on_reset): - self.sim.render() - - # trigger recorder terms for post-reset calls - self.recorder_manager.record_post_reset(reset_env_ids) - - # -- update command - self.command_manager.compute(dt=self.step_dt) - # -- step interval events - if "interval" in self.event_manager.available_modes: - self.event_manager.apply(mode="interval", dt=self.step_dt) - # -- compute observations - # note: done after reset to get the correct observations for reset envs - self.obs_buf = self.observation_manager.compute(update_history=True) + with Timer(name="post_processing", msg="Post-Processing step took:", enable=True, format="us"): + # post-step: + # -- update env counters (used for curriculum generation) + self.episode_length_buf += 1 # step in current episode (per env) + self.common_step_counter += 1 # total step (common for all envs) + # -- check terminations + self.reset_buf = self.termination_manager.compute() + self.reset_terminated = self.termination_manager.terminated + self.reset_time_outs = self.termination_manager.time_outs + # -- reward computation + self.reward_buf = self.reward_manager.compute(dt=self.step_dt) + + if len(self.recorder_manager.active_terms) > 0: + # update observations for recording if needed + self.obs_buf = self.observation_manager.compute() + self.recorder_manager.record_post_step() + + # -- reset envs that terminated/timed-out and log the episode information + reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) + if len(reset_env_ids) > 0: + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(reset_env_ids) + + self._reset_idx(reset_env_ids) + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(reset_env_ids) + + # -- update command + self.command_manager.compute(dt=self.step_dt) + # -- step interval events + if "interval" in self.event_manager.available_modes: + self.event_manager.apply(mode="interval", dt=self.step_dt) + # 
-- compute observations + # note: done after reset to get the correct observations for reset envs + self.obs_buf = self.observation_manager.compute(update_history=True) # return observations, rewards, resets and extras return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras From c276eb916e232c370d367decc42bbb8272a29902 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Mon, 26 Jan 2026 23:55:42 -0800 Subject: [PATCH 3/8] manager based workflow setup --- .../isaaclab_experimental/envs/__init__.py | 9 +- .../envs/manager_based_env_warp.py | 545 ++++++++++++++++++ .../envs/manager_based_rl_env_warp.py | 389 ++++++++++++- .../envs/mdp/__init__.py | 12 +- source/isaaclab_rl/isaaclab_rl/rl_games.py | 16 +- .../isaaclab_rl/rsl_rl/vecenv_wrapper.py | 14 +- source/isaaclab_rl/isaaclab_rl/sb3.py | 12 +- source/isaaclab_rl/isaaclab_rl/skrl.py | 8 +- .../classic/cartpole/__init__.py | 2 +- 9 files changed, 964 insertions(+), 43 deletions(-) create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py index 36702cbaa43..c352923d142 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py @@ -42,7 +42,12 @@ .. _`Task Design Workflows`: https://isaac-sim.github.io/IsaacLab/source/features/task_workflows.html """ -from isaaclab.envs import * # noqa: F401,F403 - from .direct_rl_env_warp import DirectRLEnvWarp # noqa: F401 +from .manager_based_env_warp import ManagerBasedEnvWarp # noqa: F401 from .manager_based_rl_env_warp import ManagerBasedRLEnvWarp # noqa: F401 + +__all__ = [ + "DirectRLEnvWarp", + "ManagerBasedEnvWarp", + "ManagerBasedRLEnvWarp", +] diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py new file mode 100644 index 00000000000..a1a0173b7e5 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -0,0 +1,545 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental manager-based base environment. + +This is a local copy of :class:`isaaclab.envs.ManagerBasedEnv` placed under +``isaaclab_experimental`` so we can evolve the manager-based workflow for Warp-first +pipelines without depending on (or subclassing) the stable env implementation. + +Behavior is intended to match the stable environment initially. 
+""" + +# import builtins +import contextlib +import logging +import torch +import warnings +from collections.abc import Sequence +from typing import Any + +from isaaclab.envs.common import VecEnvObs +from isaaclab.envs.manager_based_env_cfg import ManagerBasedEnvCfg +from isaaclab.envs.ui import ViewportCameraController +from isaaclab.envs.utils.io_descriptors import export_articulations_data, export_scene_data +from isaaclab.managers import ActionManager, EventManager, ObservationManager, RecorderManager +from isaaclab.scene import InteractiveScene +from isaaclab.sim import SimulationContext +from isaaclab.sim.utils import use_stage +from isaaclab.ui.widgets import ManagerLiveVisualizer +from isaaclab.utils.seed import configure_seed +from isaaclab.utils.timer import Timer + +# import logger +logger = logging.getLogger(__name__) + + +class ManagerBasedEnvWarp: + """The base environment for the manager-based workflow (experimental fork). + + The implementation mirrors :class:`isaaclab.envs.ManagerBasedEnv` to provide + an isolated base class for experimental Warp-based workflows. + """ + + def __init__(self, cfg: ManagerBasedEnvCfg): + """Initialize the environment. + + Args: + cfg: The configuration object for the environment. + + Raises: + RuntimeError: If a simulation context already exists. The environment must always create one + since it configures the simulation context and controls the simulation. + """ + # check that the config is valid + cfg.validate() + # store inputs to class + self.cfg = cfg + # initialize internal variables + self._is_closed = False + + # set the seed for the environment + if self.cfg.seed is not None: + self.cfg.seed = self.seed(self.cfg.seed) + else: + logger.warning("Seed not set for the environment. The environment creation may not be deterministic.") + + # create a simulation context to control the simulator + if SimulationContext.instance() is None: + # the type-annotation is required to avoid a type-checking error + # since it gets confused with Isaac Sim's SimulationContext class + self.sim: SimulationContext = SimulationContext(self.cfg.sim) + else: + # simulation context should only be created before the environment + # when in extension mode + # if not builtins.ISAAC_LAUNCHED_FROM_TERMINAL: + # raise RuntimeError("Simulation context already exists. Cannot create a new one.") + self.sim: SimulationContext = SimulationContext.instance() + + # make sure torch is running on the correct device + if "cuda" in self.device: + torch.cuda.set_device(self.device) + + # print useful information + print("[INFO]: Base environment:") + print(f"\tEnvironment device : {self.device}") + print(f"\tEnvironment seed : {self.cfg.seed}") + print(f"\tPhysics step-size : {self.physics_dt}") + print(f"\tRendering step-size : {self.physics_dt * self.cfg.sim.render_interval}") + print(f"\tEnvironment step-size : {self.step_dt}") + + if self.cfg.sim.render_interval < self.cfg.decimation: + msg = ( + f"The render interval ({self.cfg.sim.render_interval}) is smaller than the decimation " + f"({self.cfg.decimation}). Multiple render calls will happen for each environment step. " + "If this is not intended, set the render interval to be equal to the decimation." 
+ ) + logger.warning(msg) + + # counter for simulation steps + self._sim_step_counter = 0 + + # allocate dictionary to store metrics + self.extras = {} + + # generate scene + with Timer("[INFO]: Time taken for scene creation", "scene_creation"): + # set the stage context for scene creation steps which use the stage + with use_stage(self.sim.get_initial_stage()): + self.scene = InteractiveScene(self.cfg.scene) + # attach_stage_to_usd_context() + print("[INFO]: Scene manager: ", self.scene) + + # set up camera viewport controller + # viewport is not available in other rendering modes so the function will throw a warning + # FIXME: This needs to be fixed in the future when we unify the UI functionalities even for + # non-rendering modes. + if self.sim.render_mode >= self.sim.RenderMode.PARTIAL_RENDERING: + self.viewport_camera_controller = ViewportCameraController(self, self.cfg.viewer) + else: + self.viewport_camera_controller = None + + # create event manager + # note: this is needed here (rather than after simulation play) to allow USD-related randomization events + # that must happen before the simulation starts. Example: randomizing mesh scale + self.event_manager = EventManager(self.cfg.events, self) + + # apply USD-related randomization events + if "prestartup" in self.event_manager.available_modes: + self.event_manager.apply(mode="prestartup") + + # play the simulator to activate physics handles + # note: this activates the physics simulation view that exposes TensorAPIs + # note: when started in extension mode, first call sim.reset_async() and then initialize the managers + # if builtins.ISAAC_LAUNCHED_FROM_TERMINAL is False: + print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...") + with Timer("[INFO]: Time taken for simulation start", "simulation_start"): + # since the reset can trigger callbacks which use the stage, + # we need to set the stage context here + with use_stage(self.sim.get_initial_stage()): + self.sim.reset() + # update scene to pre populate data buffers for assets and sensors. + # this is needed for the observation manager to get valid tensors for initialization. + # this shouldn't cause an issue since later on, users do a reset over all the environments so the lazy + # buffers would be reset. + self.scene.update(dt=self.physics_dt) + # add timeline event to load managers + self.load_managers() + + # extend UI elements + # we need to do this here after all the managers are initialized + # this is because they dictate the sensors and commands right now + if self.sim.has_gui() and self.cfg.ui_window_class_type is not None: + # setup live visualizers + self.setup_manager_visualizers() + self._window = self.cfg.ui_window_class_type(self, window_name="IsaacLab") + else: + # if no window, then we don't need to store the window + self._window = None + + # initialize observation buffers + self.obs_buf = {} + + # export IO descriptors if requested + if self.cfg.export_io_descriptors: + self.export_IO_descriptors() + + # show deprecation message for rerender_on_reset + if self.cfg.rerender_on_reset: + msg = ( + "\033[93m\033[1m[DEPRECATION WARNING] ManagerBasedEnvCfg.rerender_on_reset is deprecated. 
Use" + " ManagerBasedEnvCfg.num_rerenders_on_reset instead.\033[0m" + ) + warnings.warn( + msg, + FutureWarning, + stacklevel=2, + ) + if self.cfg.num_rerenders_on_reset == 0: + self.cfg.num_rerenders_on_reset = 1 + + def __del__(self): + """Cleanup for the environment.""" + # Suppress errors during Python shutdown to avoid noisy tracebacks + # Note: contextlib may be None during interpreter shutdown + if contextlib is not None: + with contextlib.suppress(ImportError, AttributeError, TypeError): + self.close() + + """ + Properties. + """ + + @property + def num_envs(self) -> int: + """The number of instances of the environment that are running.""" + return self.scene.num_envs + + @property + def physics_dt(self) -> float: + """The physics time-step (in s). + + This is the lowest time-decimation at which the simulation is happening. + """ + return self.cfg.sim.dt + + @property + def step_dt(self) -> float: + """The environment stepping time-step (in s). + + This is the time-step at which the environment steps forward. + """ + return self.cfg.sim.dt * self.cfg.decimation + + @property + def device(self): + """The device on which the environment is running.""" + return self.sim.device + + @property + def get_IO_descriptors(self): + """Get the IO descriptors for the environment. + + Returns: + A dictionary with keys as the group names and values as the IO descriptors. + """ + return { + "observations": self.observation_manager.get_IO_descriptors, + "actions": self.action_manager.get_IO_descriptors, + "articulations": export_articulations_data(self), + "scene": export_scene_data(self), + } + + def export_IO_descriptors(self, output_dir: str | None = None): + """Export the IO descriptors for the environment. + + Args: + output_dir: The directory to export the IO descriptors to. + """ + import os + import yaml + + IO_descriptors = self.get_IO_descriptors + + if output_dir is None: + if self.cfg.log_dir is not None: + output_dir = os.path.join(self.cfg.log_dir, "io_descriptors") + else: + raise ValueError( + "Output directory is not set. Please set the log directory using the `log_dir`" + " configuration or provide an explicit output_dir parameter." + ) + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, "IO_descriptors.yaml"), "w") as f: + print(f"[INFO]: Exporting IO descriptors to {os.path.join(output_dir, 'IO_descriptors.yaml')}") + yaml.safe_dump(IO_descriptors, f) + + """ + Operations - Setup. + """ + + def load_managers(self): + """Load the managers for the environment. + + This function is responsible for creating the various managers (action, observation, + events, etc.) for the environment. Since the managers require access to physics handles, + they can only be created after the simulator is reset (i.e. played for the first time). + + .. note:: + In case of standalone application (when running simulator from Python), the function is called + automatically when the class is initialized. + + However, in case of extension mode, the user must call this function manually after the simulator + is reset. This is because the simulator is only reset when the user calls + :meth:`SimulationContext.reset_async` and it isn't possible to call async functions in the constructor. 
+ + """ + # prepare the managers + # -- event manager (we print it here to make the logging consistent) + print("[INFO] Event Manager: ", self.event_manager) + # -- recorder manager + self.recorder_manager = RecorderManager(self.cfg.recorders, self) + print("[INFO] Recorder Manager: ", self.recorder_manager) + # -- action manager + self.action_manager = ActionManager(self.cfg.actions, self) + print("[INFO] Action Manager: ", self.action_manager) + # -- observation manager + self.observation_manager = ObservationManager(self.cfg.observations, self) + print("[INFO] Observation Manager:", self.observation_manager) + + # perform events at the start of the simulation + # in-case a child implementation creates other managers, the randomization should happen + # when all the other managers are created + if self.__class__ == ManagerBasedEnvWarp and "startup" in self.event_manager.available_modes: + self.event_manager.apply(mode="startup") + + def setup_manager_visualizers(self): + """Creates live visualizers for manager terms.""" + + self.manager_visualizers = { + "action_manager": ManagerLiveVisualizer(manager=self.action_manager), + "observation_manager": ManagerLiveVisualizer(manager=self.observation_manager), + } + + """ + Operations - MDP. + """ + + def reset( + self, seed: int | None = None, env_ids: Sequence[int] | None = None, options: dict[str, Any] | None = None + ) -> tuple[VecEnvObs, dict]: + """Resets the specified environments and returns observations. + + This function calls the :meth:`_reset_idx` function to reset the specified environments. + However, certain operations, such as procedural terrain generation, that happened during initialization + are not repeated. + + Args: + seed: The seed to use for randomization. Defaults to None, in which case the seed is not set. + env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset. + options: Additional information to specify how the environment is reset. Defaults to None. + + Note: + This argument is used for compatibility with Gymnasium environment definition. + + Returns: + A tuple containing the observations and extras. + """ + if env_ids is None: + env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device) + + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(env_ids) + + # set the seed + if seed is not None: + self.seed(seed) + + # reset state of scene + self._reset_idx(env_ids) + + # update articulation kinematics + self.scene.write_data_to_sim() + self.sim.forward() + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(env_ids) + + # compute observations + self.obs_buf = self.observation_manager.compute(update_history=True) + + # return observations + return self.obs_buf, self.extras + + def reset_to( + self, + state: dict[str, dict[str, dict[str, torch.Tensor]]], + env_ids: Sequence[int] | None, + seed: int | None = None, + is_relative: bool = False, + ): + """Resets specified environments to provided states. + + This function resets the environments to the provided states. The state is a dictionary + containing the state of the scene entities. Please refer to :meth:`InteractiveScene.get_state` + for the format. 
+ + The function is different from the :meth:`reset` function as it resets the environments to specific states, + instead of using the randomization events for resetting the environments. + + Args: + state: The state to reset the specified environments to. Please refer to + :meth:`InteractiveScene.get_state` for the format. + env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset. + seed: The seed to use for randomization. Defaults to None, in which case the seed is not set. + is_relative: If set to True, the state is considered relative to the environment origins. + Defaults to False. + """ + # reset all envs in the scene if env_ids is None + if env_ids is None: + env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device) + + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(env_ids) + + # set the seed + if seed is not None: + self.seed(seed) + + self._reset_idx(env_ids) + + # set the state + self.scene.reset_to(state, env_ids, is_relative=is_relative) + + # update articulation kinematics + self.sim.forward() + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(env_ids) + + # compute observations + self.obs_buf = self.observation_manager.compute(update_history=True) + + # return observations + return self.obs_buf, self.extras + + def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]: + """Execute one time-step of the environment's dynamics. + + The environment steps forward at a fixed time-step, while the physics simulation is + decimated at a lower time-step. This is to ensure that the simulation is stable. These two + time-steps can be configured independently using the :attr:`ManagerBasedEnvCfg.decimation` (number of + simulation steps per environment step) and the :attr:`ManagerBasedEnvCfg.sim.dt` (physics time-step) + parameters. Based on these parameters, the environment time-step is computed as the product of the two. + + Args: + action: The actions to apply on the environment. Shape is (num_envs, action_dim). + + Returns: + A tuple containing the observations and extras. + """ + # process actions + self.action_manager.process_action(action.to(self.device)) + + self.recorder_manager.record_pre_step() + + # check if we need to do rendering within the physics loop + # note: checked here once to avoid multiple checks within the loop + is_rendering = self.sim.has_gui() or self.sim.has_rtx_sensors() + + # perform physics stepping + for _ in range(self.cfg.decimation): + self._sim_step_counter += 1 + # set actions into buffers + self.action_manager.apply_action() + # set actions into simulator + self.scene.write_data_to_sim() + # simulate + self.sim.step(render=False) + # render between steps only if the GUI or an RTX sensor needs it + # note: we assume the render interval to be the shortest accepted rendering interval. + # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior. 
+ if self._sim_step_counter % self.cfg.sim.render_interval == 0 and is_rendering: + self.sim.render() + # update buffers at sim dt + self.scene.update(dt=self.physics_dt) + + # post-step: step interval event + if "interval" in self.event_manager.available_modes: + self.event_manager.apply(mode="interval", dt=self.step_dt) + + # -- compute observations + self.obs_buf = self.observation_manager.compute(update_history=True) + self.recorder_manager.record_post_step() + + # return observations and extras + return self.obs_buf, self.extras + + @staticmethod + def seed(seed: int = -1) -> int: + """Set the seed for the environment. + + Args: + seed: The seed for random generator. Defaults to -1. + + Returns: + The seed used for random generator. + """ + # set seed for replicator + try: + import omni.replicator.core as rep + + rep.set_global_seed(seed) + except ModuleNotFoundError: + pass + # set seed for torch and other libraries + return configure_seed(seed) + + def close(self): + """Cleanup for the environment.""" + if not self._is_closed: + # destructor is order-sensitive + del self.viewport_camera_controller + del self.action_manager + del self.observation_manager + del self.event_manager + del self.recorder_manager + del self.scene + + # self.sim.clear_all_callbacks() + self.sim.clear_instance() + + # destroy the window + if self._window is not None: + self._window = None + # update closing status + self._is_closed = True + + """ + Helper functions. + """ + + def _reset_idx(self, env_ids: Sequence[int]): + """Reset environments based on specified indices. + + Args: + env_ids: List of environment ids which must be reset + """ + # reset the internal buffers of the scene elements + self.scene.reset(env_ids) + + # apply events such as randomization for environments that need a reset + if "reset" in self.event_manager.available_modes: + env_step_count = self._sim_step_counter // self.cfg.decimation + self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=env_step_count) + + # iterate over all managers and reset them + # this returns a dictionary of information which is stored in the extras + # note: This is order-sensitive! Certain things need be reset before others. + self.extras["log"] = dict() + # -- observation manager + info = self.observation_manager.reset(env_ids) + self.extras["log"].update(info) + # -- action manager + info = self.action_manager.reset(env_ids) + self.extras["log"].update(info) + # -- event manager + info = self.event_manager.reset(env_ids) + self.extras["log"].update(info) + # -- recorder manager + info = self.recorder_manager.reset(env_ids) + self.extras["log"].update(info) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 1775439440b..7a8e7e06030 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -3,53 +3,133 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Experimental manager-based RL env entry point. +"""Experimental manager-based RL environment (Warp entry point). -This is intentionally a minimal shim to bootstrap an experimental entry point -without copying task code. The initial behavior matches the existing -`isaaclab.envs.ManagerBasedRLEnv` exactly. 
- -Future work will incrementally replace internals with Warp-first, graph-friendly -pipelines while keeping the manager-based task authoring model. +This module provides an experimental fork of the stable manager-based RL environment +so it can diverge (Warp-first / graph-friendly) without inheriting from the stable +`isaaclab.envs.ManagerBasedRLEnv` implementation. """ +# needed to import for allowing type-hinting: np.ndarray | None from __future__ import annotations +import gymnasium as gym +import math +import numpy as np +import torch +from collections.abc import Sequence +from typing import Any, ClassVar + from isaaclab_experimental.managers import RewardManager -from isaaclab.envs import ManagerBasedEnv, ManagerBasedRLEnv +from isaaclab.envs.common import VecEnvStepReturn from isaaclab.envs.manager_based_rl_env_cfg import ManagerBasedRLEnvCfg from isaaclab.managers import CommandManager, CurriculumManager, TerminationManager +from isaaclab.ui.widgets import ManagerLiveVisualizer +from isaaclab.utils.timer import Timer + +from .manager_based_env_warp import ManagerBasedEnvWarp + + +class ManagerBasedRLEnvWarp(ManagerBasedEnvWarp, gym.Env): + """The superclass for the manager-based workflow reinforcement learning-based environments. + This class inherits from :class:`ManagerBasedEnv` and implements the core functionality for + reinforcement learning-based environments. It is designed to be used with any RL + library. The class is designed to be used with vectorized environments, i.e., the + environment is expected to be run in parallel with multiple sub-environments. The + number of sub-environments is specified using the ``num_envs``. -class ManagerBasedRLEnvWarp(ManagerBasedRLEnv): - """Experimental drop-in replacement for `ManagerBasedRLEnv`. + Each observation from the environment is a batch of observations for each sub- + environments. The method :meth:`step` is also expected to receive a batch of actions + for each sub-environment. + + While the environment itself is implemented as a vectorized environment, we do not + inherit from :class:`gym.vector.VectorEnv`. This is mainly because the class adds + various methods (for wait and asynchronous updates) which are not required. + Additionally, each RL library typically has its own definition for a vectorized + environment. Thus, to reduce complexity, we directly use the :class:`gym.Env` over + here and leave it up to library-defined wrappers to take care of wrapping this + environment for their agents. + + Note: + For vectorized environments, it is recommended to **only** call the :meth:`reset` + method once before the first call to :meth:`step`, i.e. after the environment is created. + After that, the :meth:`step` function handles the reset of terminated sub-environments. + This is because the simulator does not support resetting individual sub-environments + in a vectorized environment. - Notes: - - No behavior changes are introduced yet. This class exists to provide a new - Gym entry point (`isaaclab_experimental.envs:ManagerBasedRLEnvWarp`) that - we can evolve independently. """ + is_vector_env: ClassVar[bool] = True + """Whether the environment is a vectorized environment.""" + metadata: ClassVar[dict[str, Any]] = { + "render_modes": [None, "human", "rgb_array"], + # "isaac_sim_version": get_version(), + } + """Metadata for the environment.""" + cfg: ManagerBasedRLEnvCfg + """Configuration for the environment.""" - def load_managers(self): - """Load managers but use experimental `RewardManager`. 
+ def __init__(self, cfg: ManagerBasedRLEnvCfg, render_mode: str | None = None, **kwargs): + """Initialize the environment. - This keeps behavior identical to `isaaclab.envs.ManagerBasedRLEnv` while allowing the reward - pipeline to diverge in `isaaclab_experimental.managers.reward_manager`. + Args: + cfg: The configuration for the environment. + render_mode: The render mode for the environment. Defaults to None, which + is similar to ``"human"``. """ - # -- command manager (order matters: observations may depend on commands/actions) - self.command_manager = CommandManager(self.cfg.commands, self) + # -- counter for curriculum + self.common_step_counter = 0 + + # initialize the episode length buffer BEFORE loading the managers to use it in mdp functions. + self.episode_length_buf = torch.zeros(cfg.scene.num_envs, device=cfg.sim.device, dtype=torch.long) + + # initialize the base class to setup the scene. + super().__init__(cfg=cfg) + # store the render mode + self.render_mode = render_mode + + # initialize data and constants + # -- set the framerate of the gym video recorder wrapper so that the playback speed of the produced video matches the simulation + self.metadata["render_fps"] = 1 / self.step_dt + + print("[INFO]: Completed setting up the environment...") + + """ + Properties. + """ + + @property + def max_episode_length_s(self) -> float: + """Maximum episode length in seconds.""" + return self.cfg.episode_length_s + + @property + def max_episode_length(self) -> int: + """Maximum episode length in environment steps.""" + return math.ceil(self.max_episode_length_s / self.step_dt) + + """ + Operations - Setup. + """ + + def load_managers(self): + # note: this order is important since observation manager needs to know the command and action managers + # and the reward manager needs to know the termination manager + # -- command manager + self.command_manager: CommandManager = CommandManager(self.cfg.commands, self) print("[INFO] Command Manager: ", self.command_manager) - # call the parent class to load the managers for observations/actions/events/recorders. - ManagerBasedEnv.load_managers(self) + # call the parent class to load the managers for observations and actions. 
+ super().load_managers() + # prepare the managers # -- termination manager self.termination_manager = TerminationManager(self.cfg.terminations, self) print("[INFO] Termination Manager: ", self.termination_manager) - # -- reward manager (experimental fork) + # -- reward manager (experimental fork; Warp-compatible rewards) self.reward_manager = RewardManager(self.cfg.rewards, self) print("[INFO] Reward Manager: ", self.reward_manager) # -- curriculum manager @@ -62,3 +142,266 @@ def load_managers(self): # perform events at the start of the simulation if "startup" in self.event_manager.available_modes: self.event_manager.apply(mode="startup") + + def setup_manager_visualizers(self): + """Creates live visualizers for manager terms.""" + + self.manager_visualizers = { + "action_manager": ManagerLiveVisualizer(manager=self.action_manager), + "observation_manager": ManagerLiveVisualizer(manager=self.observation_manager), + "command_manager": ManagerLiveVisualizer(manager=self.command_manager), + "termination_manager": ManagerLiveVisualizer(manager=self.termination_manager), + "reward_manager": ManagerLiveVisualizer(manager=self.reward_manager), + "curriculum_manager": ManagerLiveVisualizer(manager=self.curriculum_manager), + } + + """ + Operations - MDP + """ + + @Timer(name="env_step", msg="Step took:", enable=True, format="us") + def step(self, action: torch.Tensor) -> VecEnvStepReturn: + """Execute one time-step of the environment's dynamics and reset terminated environments. + + Unlike the :class:`ManagerBasedEnv.step` class, the function performs the following operations: + + 1. Process the actions. + 2. Perform physics stepping. + 3. Perform rendering if gui is enabled. + 4. Update the environment counters and compute the rewards and terminations. + 5. Reset the environments that terminated. + 6. Compute the observations. + 7. Return the observations, rewards, resets and extras. + + Args: + action: The actions to apply on the environment. Shape is (num_envs, action_dim). + + Returns: + A tuple containing the observations, rewards, resets (terminated and truncated) and extras. + """ + # process actions + self.action_manager.process_action(action.to(self.device)) + + self.recorder_manager.record_pre_step() + + # check if we need to do rendering within the physics loop + # note: checked here once to avoid multiple checks within the loop + is_rendering = self.sim.has_gui() or self.sim.has_rtx_sensors() + + # perform physics stepping + for _ in range(self.cfg.decimation): + self._sim_step_counter += 1 + # set actions into buffers + with Timer(name="apply_action", msg="Action processing step took:", enable=True, format="us"): + self.action_manager.apply_action() + # set actions into simulator + self.scene.write_data_to_sim() + # simulate + with Timer(name="simulate", msg="Newton simulation step took:", enable=True, format="us"): + self.sim.step(render=False) + self.recorder_manager.record_post_physics_decimation_step() + # render between steps only if the GUI or an RTX sensor needs it + # note: we assume the render interval to be the shortest accepted rendering interval. + # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior. 
+ if self._sim_step_counter % self.cfg.sim.render_interval == 0 and is_rendering: + self.sim.render() + # update buffers at sim dt + self.scene.update(dt=self.physics_dt) + + with Timer(name="post_processing", msg="Post-Processing step took:", enable=True, format="us"): + # post-step: + # -- update env counters (used for curriculum generation) + self.episode_length_buf += 1 # step in current episode (per env) + self.common_step_counter += 1 # total step (common for all envs) + # -- check terminations + self.reset_buf = self.termination_manager.compute() + self.reset_terminated = self.termination_manager.terminated + self.reset_time_outs = self.termination_manager.time_outs + # -- reward computation + self.reward_buf = self.reward_manager.compute(dt=self.step_dt) + + if len(self.recorder_manager.active_terms) > 0: + # update observations for recording if needed + self.obs_buf = self.observation_manager.compute() + self.recorder_manager.record_post_step() + + # -- reset envs that terminated/timed-out and log the episode information + reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) + if len(reset_env_ids) > 0: + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(reset_env_ids) + + self._reset_idx(reset_env_ids) + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(reset_env_ids) + + # -- update command + self.command_manager.compute(dt=self.step_dt) + # -- step interval events + if "interval" in self.event_manager.available_modes: + self.event_manager.apply(mode="interval", dt=self.step_dt) + # -- compute observations + # note: done after reset to get the correct observations for reset envs + self.obs_buf = self.observation_manager.compute(update_history=True) + + # return observations, rewards, resets and extras + return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras + + def render(self, recompute: bool = False) -> np.ndarray | None: + """Run rendering without stepping through the physics. + + By convention, if mode is: + + - **human**: Render to the current display and return nothing. Usually for human consumption. + - **rgb_array**: Return a numpy.ndarray with shape (x, y, 3), representing RGB values for an + x-by-y pixel image, suitable for turning into a video. + + Args: + recompute: Whether to force a render even if the simulator has already rendered the scene. + Defaults to False. + + Returns: + The rendered image as a numpy array if mode is "rgb_array". Otherwise, returns None. + + Raises: + RuntimeError: If mode is set to "rgb_data" and simulation render mode does not support it. + In this case, the simulation render mode must be set to ``RenderMode.PARTIAL_RENDERING`` + or ``RenderMode.FULL_RENDERING``. + NotImplementedError: If an unsupported rendering mode is specified. 
+        """
+        # run a rendering step of the simulator
+        # if we have rtx sensors, we do not need to render again since rendering already ran for the sensors
+        if not self.sim.has_rtx_sensors() and not recompute:
+            self.sim.render()
+        # decide the rendering mode
+        if self.render_mode == "human" or self.render_mode is None:
+            return None
+        elif self.render_mode == "rgb_array":
+            # check if any rendering could have happened
+            if self.sim.render_mode.value < self.sim.RenderMode.PARTIAL_RENDERING.value:
+                raise RuntimeError(
+                    f"Cannot render '{self.render_mode}' when the simulation render mode is"
+                    f" '{self.sim.render_mode.name}'. Please set the simulation render mode to:"
+                    f"'{self.sim.RenderMode.PARTIAL_RENDERING.name}' or '{self.sim.RenderMode.FULL_RENDERING.name}'."
+                    " If running headless, make sure --enable_cameras is set."
+                )
+            # create the annotator if it does not exist
+            if not hasattr(self, "_rgb_annotator"):
+                import omni.replicator.core as rep
+
+                # create render product
+                self._render_product = rep.create.render_product(
+                    self.cfg.viewer.cam_prim_path, self.cfg.viewer.resolution
+                )
+                # create rgb annotator -- used to read data from the render product
+                self._rgb_annotator = rep.AnnotatorRegistry.get_annotator("rgb", device="cpu")
+                self._rgb_annotator.attach([self._render_product])
+            # obtain the rgb data
+            rgb_data = self._rgb_annotator.get_data()
+            # convert to numpy array
+            rgb_data = np.frombuffer(rgb_data, dtype=np.uint8).reshape(*rgb_data.shape)
+            # return the rgb data
+            # note: initially the renderer is warming up and returns empty data
+            if rgb_data.size == 0:
+                return np.zeros((self.cfg.viewer.resolution[1], self.cfg.viewer.resolution[0], 3), dtype=np.uint8)
+            else:
+                return rgb_data[:, :, :3]
+        else:
+            raise NotImplementedError(
+                f"Render mode '{self.render_mode}' is not supported. Please use: {self.metadata['render_modes']}."
+            )
+
+    def close(self):
+        if not self._is_closed:
+            # destructor is order-sensitive
+            del self.command_manager
+            del self.reward_manager
+            del self.termination_manager
+            del self.curriculum_manager
+            # call the parent class to close the environment
+            super().close()
+
+    """
+    Helper functions.
+ """ + + def _configure_gym_env_spaces(self): + """Configure the action and observation spaces for the Gym environment.""" + # observation space (unbounded since we don't impose any limits) + self.single_observation_space = gym.spaces.Dict() + for group_name, group_term_names in self.observation_manager.active_terms.items(): + # extract quantities about the group + has_concatenated_obs = self.observation_manager.group_obs_concatenate[group_name] + group_dim = self.observation_manager.group_obs_dim[group_name] + # check if group is concatenated or not + # if not concatenated, then we need to add each term separately as a dictionary + if has_concatenated_obs: + self.single_observation_space[group_name] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=group_dim) + else: + group_term_cfgs = self.observation_manager._group_obs_term_cfgs[group_name] + term_dict = {} + for term_name, term_dim, term_cfg in zip(group_term_names, group_dim, group_term_cfgs): + low = -np.inf if term_cfg.clip is None else term_cfg.clip[0] + high = np.inf if term_cfg.clip is None else term_cfg.clip[1] + term_dict[term_name] = gym.spaces.Box(low=low, high=high, shape=term_dim) + self.single_observation_space[group_name] = gym.spaces.Dict(term_dict) + # action space (unbounded since we don't impose any limits) + action_dim = sum(self.action_manager.action_term_dim) + self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,)) + + # batch the spaces for vectorized environments + self.observation_space = gym.vector.utils.batch_space(self.single_observation_space, self.num_envs) + self.action_space = gym.vector.utils.batch_space(self.single_action_space, self.num_envs) + + def _reset_idx(self, env_ids: Sequence[int]): + """Reset environments based on specified indices. + + Args: + env_ids: List of environment ids which must be reset + """ + # update the curriculum for environments that need a reset + self.curriculum_manager.compute(env_ids=env_ids) + # reset the internal buffers of the scene elements + self.scene.reset(env_ids) + # apply events such as randomizations for environments that need a reset + if "reset" in self.event_manager.available_modes: + env_step_count = self._sim_step_counter // self.cfg.decimation + self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=env_step_count) + + # iterate over all managers and reset them + # this returns a dictionary of information which is stored in the extras + # note: This is order-sensitive! Certain things need be reset before others. 
+ self.extras["log"] = dict() + # -- observation manager + info = self.observation_manager.reset(env_ids) + self.extras["log"].update(info) + # -- action manager + info = self.action_manager.reset(env_ids) + self.extras["log"].update(info) + # -- rewards manager + info = self.reward_manager.reset(env_ids) + self.extras["log"].update(info) + # -- curriculum manager + info = self.curriculum_manager.reset(env_ids) + self.extras["log"].update(info) + # -- command manager + info = self.command_manager.reset(env_ids) + self.extras["log"].update(info) + # -- event manager + info = self.event_manager.reset(env_ids) + self.extras["log"].update(info) + # -- termination manager + info = self.termination_manager.reset(env_ids) + self.extras["log"].update(info) + # -- recorder manager + info = self.recorder_manager.reset(env_ids) + self.extras["log"].update(info) + + # reset the episode length buffer + self.episode_length_buf[env_ids] = 0 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py index 46d8d4c7015..74858972a57 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py @@ -9,8 +9,16 @@ functions with Warp-first implementations from :mod:`isaaclab_experimental.envs.mdp.rewards`. """ -# Forward stable MDP terms (actions/observations/terminations/etc.) -from isaaclab.envs.mdp import * # noqa: F401, F403 +# Forward stable MDP terms (actions/observations/terminations/etc.) but *exclude* rewards. +# Rewards are provided by this experimental package to keep a Warp-compatible signature +# (`func(env, out, **params) -> None`) for the experimental RewardManager. +from isaaclab.envs.mdp.actions import * # noqa: F401, F403 +from isaaclab.envs.mdp.commands import * # noqa: F401, F403 +from isaaclab.envs.mdp.curriculums import * # noqa: F401, F403 +from isaaclab.envs.mdp.events import * # noqa: F401, F403 +from isaaclab.envs.mdp.observations import * # noqa: F401, F403 +from isaaclab.envs.mdp.recorders import * # noqa: F401, F403 +from isaaclab.envs.mdp.terminations import * # noqa: F401, F403 # Override reward terms with experimental implementations. from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_rl/isaaclab_rl/rl_games.py b/source/isaaclab_rl/isaaclab_rl/rl_games.py index c182d2ff5a1..595f97fd97f 100644 --- a/source/isaaclab_rl/isaaclab_rl/rl_games.py +++ b/source/isaaclab_rl/isaaclab_rl/rl_games.py @@ -38,7 +38,7 @@ import gymnasium import torch -from isaaclab_experimental.envs import DirectRLEnvWarp +from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp from rl_games.common import env_configurations from rl_games.common.vecenv import IVecEnv @@ -80,7 +80,13 @@ class RlGamesVecEnvWrapper(IVecEnv): https://github.com/NVIDIA-Omniverse/IsaacGymEnvs """ - def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, rl_device: str, clip_obs: float, clip_actions: float): + def __init__( + self, + env: ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp, + rl_device: str, + clip_obs: float, + clip_actions: float, + ): """Initializes the wrapper instance. 
Args: @@ -98,9 +104,11 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, rl_device: str, clip_ob not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv) and not isinstance(env.unwrapped, DirectRLEnvWarp) + and not isinstance(env.unwrapped, ManagerBasedRLEnvWarp) ): raise ValueError( - "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + " ManagerBasedRLEnvWarp. Environment type:" f" {type(env)}" ) # initialize the wrapper @@ -176,7 +184,7 @@ def class_name(cls) -> str: return cls.__name__ @property - def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv: + def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp: """Returns the base environment of the wrapper. This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. diff --git a/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py b/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py index 74d57681432..910b801c7b0 100644 --- a/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py +++ b/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py @@ -8,7 +8,7 @@ from tensordict import TensorDict import warp as wp -from isaaclab_experimental.envs import DirectRLEnvWarp +from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp from rsl_rl.env import VecEnv from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv @@ -26,7 +26,11 @@ class RslRlVecEnvWrapper(VecEnv): https://github.com/leggedrobotics/rsl_rl/blob/master/rsl_rl/env/vec_env.py """ - def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp, clip_actions: float | None = None): + def __init__( + self, + env: ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp, + clip_actions: float | None = None, + ): """Initializes the wrapper. Note: @@ -45,9 +49,11 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp, clip_ not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv) and not isinstance(env.unwrapped, DirectRLEnvWarp) + and not isinstance(env.unwrapped, ManagerBasedRLEnvWarp) ): raise ValueError( - "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + " ManagerBasedRLEnvWarp. Environment type:" f" {type(env)}" ) @@ -110,7 +116,7 @@ def class_name(cls) -> str: return cls.__name__ @property - def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp: + def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp: """Returns the base environment of the wrapper. This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. 
diff --git a/source/isaaclab_rl/isaaclab_rl/sb3.py b/source/isaaclab_rl/isaaclab_rl/sb3.py index d560bf78e10..941735a6ef1 100644 --- a/source/isaaclab_rl/isaaclab_rl/sb3.py +++ b/source/isaaclab_rl/isaaclab_rl/sb3.py @@ -25,7 +25,7 @@ import warnings from typing import Any -from isaaclab_experimental.envs import DirectRLEnvWarp +from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.utils import constant_fn from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn @@ -136,7 +136,9 @@ class Sb3VecEnvWrapper(VecEnv): """ - def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, fast_variant: bool = True): + def __init__( + self, env: ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp, fast_variant: bool = True + ): """Initialize the wrapper. Args: @@ -151,9 +153,11 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, fast_variant: bool = Tr not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv) and not isinstance(env.unwrapped, DirectRLEnvWarp) + and not isinstance(env.unwrapped, ManagerBasedRLEnvWarp) ): raise ValueError( - "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + " ManagerBasedRLEnvWarp. Environment type:" f" {type(env)}" ) # initialize the wrapper @@ -187,7 +191,7 @@ def class_name(cls) -> str: return cls.__name__ @property - def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv: + def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp: """Returns the base environment of the wrapper. This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. diff --git a/source/isaaclab_rl/isaaclab_rl/skrl.py b/source/isaaclab_rl/isaaclab_rl/skrl.py index 4fb9daca3aa..6dab41f37d7 100644 --- a/source/isaaclab_rl/isaaclab_rl/skrl.py +++ b/source/isaaclab_rl/isaaclab_rl/skrl.py @@ -29,7 +29,7 @@ from typing import Literal -from isaaclab_experimental.envs import DirectRLEnvWarp +from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv @@ -39,7 +39,7 @@ def SkrlVecEnvWrapper( - env: ManagerBasedRLEnv | DirectRLEnv, + env: ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch", wrapper: Literal["auto", "isaaclab", "isaaclab-single-agent", "isaaclab-multi-agent"] = "isaaclab", ): @@ -68,9 +68,11 @@ def SkrlVecEnvWrapper( not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv) and not isinstance(env.unwrapped, DirectRLEnvWarp) + and not isinstance(env.unwrapped, ManagerBasedRLEnvWarp) ): raise ValueError( - f"The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type: {type(env)}" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + f" ManagerBasedRLEnvWarp. 
Environment type: {type(env)}" ) # import statements according to the ML framework diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py index 7bdf0a4a289..17a4c5c03cd 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py @@ -10,7 +10,7 @@ import gymnasium as gym gym.register( - id="Isaac-Cartpole-Managed-Warp-v0", + id="Isaac-Cartpole-Warp-v0", entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", disable_env_checker=True, kwargs={ From 5f9d32bb0484c6d1ae4d35a219c29cc2320df091 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Mon, 26 Jan 2026 23:56:31 -0800 Subject: [PATCH 4/8] Made timer command line argument for train.py --- scripts/reinforcement_learning/rsl_rl/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/train.py b/scripts/reinforcement_learning/rsl_rl/train.py index 5c502fa044c..6301121859b 100644 --- a/scripts/reinforcement_learning/rsl_rl/train.py +++ b/scripts/reinforcement_learning/rsl_rl/train.py @@ -32,6 +32,7 @@ "--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes." ) parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.") +parser.add_argument("--timer", action="store_true", default=False, help="Enable IsaacLab Timer measurements/output.") # append RSL-RL cli arguments cli_args.add_rsl_rl_args(parser) # append AppLauncher cli args @@ -83,8 +84,8 @@ from isaaclab.utils.timer import Timer -Timer.enable = False -Timer.enable_display_output = False +Timer.enable = args_cli.timer +Timer.enable_display_output = args_cli.timer import isaaclab_tasks_experimental # noqa: F401 From 172441d7da123da4279aac7d9c8923c4336545ca Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Mon, 26 Jan 2026 23:57:04 -0800 Subject: [PATCH 5/8] Migrated reward manager to warp --- .../managers/reward_manager.py | 329 ++++++++++++++++++ .../managers/scene_entity_cfg.py | 5 +- .../classic/cartpole/mdp/rewards.py | 1 + 3 files changed, 331 insertions(+), 4 deletions(-) create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py new file mode 100644 index 00000000000..e905bd6e25f --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py @@ -0,0 +1,329 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Reward manager for computing reward signals for a given world. + +This file is a copy of `isaaclab.managers.reward_manager` placed under +`isaaclab_experimental` so it can evolve independently. 
+""" + +from __future__ import annotations + +import torch +from collections.abc import Sequence +from prettytable import PrettyTable +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab.managers.manager_base import ManagerBase, ManagerTermBase + +from .manager_term_cfg import RewardTermCfg + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +@wp.kernel +def _sum_scaled_selected( + values: wp.array(dtype=wp.float32), env_ids: wp.array(dtype=wp.int64), out: wp.array(dtype=wp.float32), scale: float +): + i = wp.tid() + idx = wp.int32(env_ids[i]) + wp.atomic_add(out, 0, values[idx] * scale) + + +@wp.kernel +def _zero_selected(values: wp.array(dtype=wp.float32), env_ids: wp.array(dtype=wp.int64)): + i = wp.tid() + idx = wp.int32(env_ids[i]) + values[idx] = 0.0 + + +@wp.kernel +def _accumulate_reward_term( + term_out: wp.array(dtype=wp.float32), + reward_buf: wp.array(dtype=wp.float32), + episode_sum: wp.array(dtype=wp.float32), + step_reward: wp.array(dtype=wp.float32, ndim=2), + term_idx: int, + weight: float, + dt: float, +): + i = wp.tid() + raw = term_out[i] + weighted = raw * weight + reward_buf[i] += weighted * dt + episode_sum[i] += weighted * dt + # store weighted reward rate (matches old: value/dt) + step_reward[i, term_idx] = weighted + + +class RewardManager(ManagerBase): + """Manager for computing reward signals for a given world. + + The reward manager computes the total reward as a sum of the weighted reward terms. The reward + terms are parsed from a nested config class containing the reward manger's settings and reward + terms configuration. + + The reward terms are parsed from a config class containing the manager's settings and each term's + parameters. Each reward term should instantiate the :class:`RewardTermCfg` class. + + .. note:: + + The reward manager multiplies the reward term's ``weight`` with the time-step interval ``dt`` + of the environment. This is done to ensure that the computed reward terms are balanced with + respect to the chosen time-step interval in the environment. + + """ + + _env: ManagerBasedRLEnv + """The environment instance.""" + + def __init__(self, cfg: object, env: ManagerBasedRLEnv): + """Initialize the reward manager. + + Args: + cfg: The configuration object or dictionary (``dict[str, RewardTermCfg]``). + env: The environment instance. + """ + + # create buffers to parse and store terms + self._term_names: list[str] = list() + self._term_cfgs: list[RewardTermCfg] = list() + self._class_term_cfgs: list[RewardTermCfg] = list() + + # call the base class constructor (this will parse the terms config) + super().__init__(cfg, env) + # allocate persistent warp output buffer for each term (raw, unweighted) + # TODO(jichuanh): What's the best way? Can it be done in the term base class? 
+ for term_cfg in self._term_cfgs: + term_cfg.out = wp.zeros((self.num_envs,), dtype=wp.float32, device=self.device) + + # prepare extra info to store individual reward term information (warp buffers) + self._episode_sums = {} + for term_name in self._term_names: + self._episode_sums[term_name] = wp.zeros((self.num_envs,), dtype=wp.float32, device=self.device) + # create buffer for managing reward per environment (warp buffer) + self._reward_buf = wp.zeros((self.num_envs,), dtype=wp.float32, device=self.device) + + # buffer which stores the current step reward rate for each term for each environment (warp buffer) + self._step_reward = wp.zeros((self.num_envs, len(self._term_names)), dtype=wp.float32, device=self.device) + + # persistent "all env ids" buffer for reset() reductions (warp buffer) + self._all_env_ids_wp = wp.array(list(range(self.num_envs)), dtype=wp.int64, device=self.device) + + # per-term scalar buffers used for reset-time logging (warp buffers) + self._episode_sum_avg = {} + for term_name in self._term_names: + self._episode_sum_avg[term_name] = wp.zeros((1,), dtype=wp.float32, device=self.device) + + def __str__(self) -> str: + """Returns: A string representation for reward manager.""" + msg = f" contains {len(self._term_names)} active terms.\n" + + # create table for term information + table = PrettyTable() + table.title = "Active Reward Terms" + table.field_names = ["Index", "Name", "Weight"] + # set alignment of table columns + table.align["Name"] = "l" + table.align["Weight"] = "r" + # add info on each term + for index, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)): + table.add_row([index, name, term_cfg.weight]) + # convert table to string + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. + """ + + @property + def active_terms(self) -> list[str]: + """Name of active reward terms.""" + return self._term_names + + """ + Operations. + """ + + def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]: + """Returns the episodic sum of individual reward terms. + + Args: + env_ids: The environment ids for which the episodic sum of + individual reward terms is to be returned. Defaults to all the environment ids. + + Returns: + Dictionary of episodic sum of individual reward terms. + """ + extras = {} + + # resolve env_ids into a warp array view (int64 to match torch nonzero dtype) + if env_ids is None: + env_ids_wp = self._all_env_ids_wp + num_ids = self.num_envs + elif isinstance(env_ids, torch.Tensor): + env_ids_wp = wp.from_torch(env_ids, dtype=wp.int64) + num_ids = int(env_ids.numel()) + else: + env_ids_wp = wp.array(env_ids, dtype=wp.int64, device=self.device) + num_ids = len(env_ids) + + # compute and reset episodic sums + for key, episode_sum in self._episode_sums.items(): + avg_scalar = self._episode_sum_avg[key] + avg_scalar.zero_() + scale = 1.0 / (num_ids * self._env.max_episode_length_s) + wp.launch( + kernel=_sum_scaled_selected, + dim=num_ids, + inputs=[episode_sum, env_ids_wp, avg_scalar, scale], + device=self.device, + ) + wp.launch(kernel=_zero_selected, dim=num_ids, inputs=[episode_sum, env_ids_wp], device=self.device) + + extras["Episode_Reward/" + key] = wp.to_torch(avg_scalar) + # reset all the reward terms + for term_cfg in self._class_term_cfgs: + term_cfg.func.reset(env_ids=env_ids) + # return logged information + return extras + + def compute(self, dt: float) -> torch.Tensor: + """Computes the reward signal as a weighted sum of individual terms. 
+ + This function calls each reward term managed by the class and adds them to compute the net + reward signal. It also updates the episodic sums corresponding to individual reward terms. + + Args: + dt: The time-step interval of the environment. + + Returns: + The net reward signal of shape (num_envs,). + """ + # reset computation + self._reward_buf.fill_(0.0) + self._step_reward.fill_(0.0) + # iterate over all the reward terms (Python loop; per-term math is warp) + for term_idx, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)): + # skip if weight is zero (kind of a micro-optimization) + if term_cfg.weight == 0.0: + continue + # compute term into the persistent warp buffer (raw, unweighted) + # NOTE: Ensure the term defines all entries each step. This prevents stale values + # from leaking if a term only conditionally writes to `out`. + term_cfg.out.fill_(0.0) + term_cfg.func(self._env, term_cfg.out, **term_cfg.params) + # update total reward, episodic sums and step rewards in warp + wp.launch( + kernel=_accumulate_reward_term, + dim=self.num_envs, + inputs=[ + term_cfg.out, + self._reward_buf, + self._episode_sums[name], + self._step_reward, + int(term_idx), + float(term_cfg.weight), + float(dt), + ], + device=self.device, + ) + + return wp.to_torch(self._reward_buf) + + """ + Operations - Term settings. + """ + + def set_term_cfg(self, term_name: str, cfg: RewardTermCfg): + """Sets the configuration of the specified term into the manager. + + Args: + term_name: The name of the reward term. + cfg: The configuration for the reward term. + + Raises: + ValueError: If the term name is not found. + """ + if term_name not in self._term_names: + raise ValueError(f"Reward term '{term_name}' not found.") + # set the configuration + self._term_cfgs[self._term_names.index(term_name)] = cfg + + def get_term_cfg(self, term_name: str) -> RewardTermCfg: + """Gets the configuration for the specified term. + + Args: + term_name: The name of the reward term. + + Returns: + The configuration of the reward term. + + Raises: + ValueError: If the term name is not found. + """ + if term_name not in self._term_names: + raise ValueError(f"Reward term '{term_name}' not found.") + # return the configuration + return self._term_cfgs[self._term_names.index(term_name)] + + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + terms = [] + step_reward_torch = wp.to_torch(self._step_reward) + for idx, name in enumerate(self._term_names): + terms.append((name, [step_reward_torch[env_idx, idx].cpu().item()])) + return terms + + """ + Helper functions. + """ + + def _prepare_terms(self): + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + # iterate over all the terms + for term_name, term_cfg in cfg_items: + # check for non config + if term_cfg is None: + continue + # check for valid config type + if not isinstance(term_cfg, RewardTermCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type RewardTermCfg." + f" Received: '{type(term_cfg)}'." 
+ ) + # check for valid weight type + if not isinstance(term_cfg.weight, (float, int)): + raise TypeError( + f"Weight for the term '{term_name}' is not of type float or int." + f" Received: '{type(term_cfg.weight)}'." + ) + # resolve common parameters + self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2) + # add function to list + self._term_names.append(term_name) + self._term_cfgs.append(term_cfg) + # check if the term is a class + if isinstance(term_cfg.func, ManagerTermBase): + self._class_term_cfgs.append(term_cfg) diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py index 612e2695aef..33abcd4f829 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py @@ -32,10 +32,7 @@ def resolve(self, scene: InteractiveScene): super().resolve(scene) # Build a Warp joint mask for articulations. - entity = scene[self.name] - if not isinstance(entity, Articulation): - self.joint_mask = None - return + entity: Articulation = scene[self.name] # Pre-allocate a full-length mask (all True for default selection). if self.joint_ids == slice(None): diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py index c5043964f17..fbb426751fa 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py @@ -37,6 +37,7 @@ def _joint_pos_target_l2_kernel( def joint_pos_target_l2(env: ManagerBasedRLEnv, out, target: float, asset_cfg: SceneEntityCfg) -> None: """Penalize joint position deviation from a target value. Writes into ``out``.""" asset: Articulation = env.scene[asset_cfg.name] + assert asset.data.joint_pos.shape[1] == asset_cfg.joint_mask.shape[0] wp.launch( kernel=_joint_pos_target_l2_kernel, dim=env.num_envs, From ef0e0f64e273e7350ca72570287a7326e9063b3a Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Tue, 27 Jan 2026 01:59:26 -0800 Subject: [PATCH 6/8] Added per manager profile --- .../envs/manager_based_rl_env_warp.py | 119 ++++++++++++------ 1 file changed, 82 insertions(+), 37 deletions(-) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 7a8e7e06030..89579d5422c 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -180,7 +180,14 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: A tuple containing the observations, rewards, resets (terminated and truncated) and extras. 
""" # process actions - self.action_manager.process_action(action.to(self.device)) + action_device = action.to(self.device) + with Timer( + name="action_manager.process_action", + msg="ActionManager.process_action took:", + enable=True, + format="us", + ): + self.action_manager.process_action(action_device) self.recorder_manager.record_pre_step() @@ -192,10 +199,15 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: for _ in range(self.cfg.decimation): self._sim_step_counter += 1 # set actions into buffers - with Timer(name="apply_action", msg="Action processing step took:", enable=True, format="us"): + with Timer( + name="action_manager.apply_action", + msg="ActionManager.apply_action took:", + enable=True, + format="us", + ): self.action_manager.apply_action() - # set actions into simulator - self.scene.write_data_to_sim() + # set actions into simulator + self.scene.write_data_to_sim() # simulate with Timer(name="simulate", msg="Newton simulation step took:", enable=True, format="us"): self.sim.step(render=False) @@ -208,46 +220,79 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: # update buffers at sim dt self.scene.update(dt=self.physics_dt) - with Timer(name="post_processing", msg="Post-Processing step took:", enable=True, format="us"): - # post-step: - # -- update env counters (used for curriculum generation) - self.episode_length_buf += 1 # step in current episode (per env) - self.common_step_counter += 1 # total step (common for all envs) - # -- check terminations + # post-step: + # -- update env counters (used for curriculum generation) + self.episode_length_buf += 1 # step in current episode (per env) + self.common_step_counter += 1 # total step (common for all envs) + + # -- check terminations + with Timer( + name="termination_manager.compute", + msg="TerminationManager.compute took:", + enable=True, + format="us", + ): self.reset_buf = self.termination_manager.compute() self.reset_terminated = self.termination_manager.terminated self.reset_time_outs = self.termination_manager.time_outs - # -- reward computation + + # -- reward computation + with Timer( + name="reward_manager.compute", + msg="RewardManager.compute took:", + enable=True, + format="us", + ): self.reward_buf = self.reward_manager.compute(dt=self.step_dt) - if len(self.recorder_manager.active_terms) > 0: - # update observations for recording if needed + if len(self.recorder_manager.active_terms) > 0: + # update observations for recording if needed + with Timer( + name="observation_manager.compute", + msg="ObservationManager.compute took:", + enable=True, + format="us", + ): self.obs_buf = self.observation_manager.compute() - self.recorder_manager.record_post_step() - - # -- reset envs that terminated/timed-out and log the episode information - reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) - if len(reset_env_ids) > 0: - # trigger recorder terms for pre-reset calls - self.recorder_manager.record_pre_reset(reset_env_ids) - - self._reset_idx(reset_env_ids) - - # if sensors are added to the scene, make sure we render to reflect changes in reset - if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: - for _ in range(self.cfg.num_rerenders_on_reset): - self.sim.render() - - # trigger recorder terms for post-reset calls - self.recorder_manager.record_post_reset(reset_env_ids) - - # -- update command + self.recorder_manager.record_post_step() + + # -- reset envs that terminated/timed-out and log the episode information + reset_env_ids = 
self.reset_buf.nonzero(as_tuple=False).squeeze(-1) + if len(reset_env_ids) > 0: + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(reset_env_ids) + + self._reset_idx(reset_env_ids) + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(reset_env_ids) + + # -- update command + with Timer( + name="command_manager.compute", + msg="CommandManager.compute took:", + enable=True, + format="us", + ): self.command_manager.compute(dt=self.step_dt) - # -- step interval events - if "interval" in self.event_manager.available_modes: - self.event_manager.apply(mode="interval", dt=self.step_dt) - # -- compute observations - # note: done after reset to get the correct observations for reset envs + + # -- step interval events + if "interval" in self.event_manager.available_modes: + self.event_manager.apply(mode="interval", dt=self.step_dt) + + # -- compute observations + # note: done after reset to get the correct observations for reset envs + with Timer( + name="observation_manager.compute_update_history", + msg="ObservationManager.compute (update_history) took:", + enable=True, + format="us", + ): self.obs_buf = self.observation_manager.compute(update_history=True) # return observations, rewards, resets and extras From cbd1b6118a75afa536c09e95bea4a206861cb717 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Thu, 29 Jan 2026 01:43:37 -0800 Subject: [PATCH 7/8] Migrated reward and action manager to warp --- .../isaaclab/envs/manager_based_rl_env.py | 119 +++-- .../envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md | 130 +++++ .../envs/manager_based_env_warp.py | 19 +- .../envs/manager_based_rl_env_warp.py | 11 +- .../envs/mdp/__init__.py | 7 +- .../envs/mdp/actions/__init__.py | 13 + .../envs/mdp/actions/actions_cfg.py | 42 ++ .../envs/mdp/actions/joint_actions.py | 272 ++++++++++ .../managers/__init__.py | 1 + .../managers/action_manager.py | 485 ++++++++++++++++++ .../managers/manager_base.py | 435 ++++++++++++++++ 11 files changed, 1487 insertions(+), 47 deletions(-) create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py diff --git a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py index 0e26efebb58..10aa9b84b25 100644 --- a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py +++ b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py @@ -171,7 +171,14 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: A tuple containing the observations, rewards, resets (terminated and truncated) and extras. 
""" # process actions - self.action_manager.process_action(action.to(self.device)) + action_device = action.to(self.device) + with Timer( + name="action_manager.process_action", + msg="ActionManager.process_action took:", + enable=True, + format="us", + ): + self.action_manager.process_action(action_device) self.recorder_manager.record_pre_step() @@ -183,10 +190,15 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: for _ in range(self.cfg.decimation): self._sim_step_counter += 1 # set actions into buffers - with Timer(name="apply_action", msg="Action processing step took:", enable=True, format="us"): + with Timer( + name="action_manager.apply_action", + msg="ActionManager.apply_action took:", + enable=True, + format="us", + ): self.action_manager.apply_action() - # set actions into simulator - self.scene.write_data_to_sim() + # set actions into simulator + self.scene.write_data_to_sim() # simulate with Timer(name="simulate", msg="Newton simulation step took:", enable=True, format="us"): self.sim.step(render=False) @@ -199,46 +211,79 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: # update buffers at sim dt self.scene.update(dt=self.physics_dt) - with Timer(name="post_processing", msg="Post-Processing step took:", enable=True, format="us"): - # post-step: - # -- update env counters (used for curriculum generation) - self.episode_length_buf += 1 # step in current episode (per env) - self.common_step_counter += 1 # total step (common for all envs) - # -- check terminations + # post-step: + # -- update env counters (used for curriculum generation) + self.episode_length_buf += 1 # step in current episode (per env) + self.common_step_counter += 1 # total step (common for all envs) + + # -- check terminations + with Timer( + name="termination_manager.compute", + msg="TerminationManager.compute took:", + enable=True, + format="us", + ): self.reset_buf = self.termination_manager.compute() self.reset_terminated = self.termination_manager.terminated self.reset_time_outs = self.termination_manager.time_outs - # -- reward computation + + # -- reward computation + with Timer( + name="reward_manager.compute", + msg="RewardManager.compute took:", + enable=True, + format="us", + ): self.reward_buf = self.reward_manager.compute(dt=self.step_dt) - if len(self.recorder_manager.active_terms) > 0: - # update observations for recording if needed + if len(self.recorder_manager.active_terms) > 0: + # update observations for recording if needed + with Timer( + name="observation_manager.compute", + msg="ObservationManager.compute took:", + enable=True, + format="us", + ): self.obs_buf = self.observation_manager.compute() - self.recorder_manager.record_post_step() - - # -- reset envs that terminated/timed-out and log the episode information - reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) - if len(reset_env_ids) > 0: - # trigger recorder terms for pre-reset calls - self.recorder_manager.record_pre_reset(reset_env_ids) - - self._reset_idx(reset_env_ids) - - # if sensors are added to the scene, make sure we render to reflect changes in reset - if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: - for _ in range(self.cfg.num_rerenders_on_reset): - self.sim.render() - - # trigger recorder terms for post-reset calls - self.recorder_manager.record_post_reset(reset_env_ids) - - # -- update command + self.recorder_manager.record_post_step() + + # -- reset envs that terminated/timed-out and log the episode information + reset_env_ids = 
self.reset_buf.nonzero(as_tuple=False).squeeze(-1) + if len(reset_env_ids) > 0: + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(reset_env_ids) + + self._reset_idx(reset_env_ids) + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.has_rtx_sensors() and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(reset_env_ids) + + # -- update command + with Timer( + name="command_manager.compute", + msg="CommandManager.compute took:", + enable=True, + format="us", + ): self.command_manager.compute(dt=self.step_dt) - # -- step interval events - if "interval" in self.event_manager.available_modes: - self.event_manager.apply(mode="interval", dt=self.step_dt) - # -- compute observations - # note: done after reset to get the correct observations for reset envs + + # -- step interval events + if "interval" in self.event_manager.available_modes: + self.event_manager.apply(mode="interval", dt=self.step_dt) + + # -- compute observations + # note: done after reset to get the correct observations for reset envs + with Timer( + name="observation_manager.compute_update_history", + msg="ObservationManager.compute (update_history) took:", + enable=True, + format="us", + ): self.obs_buf = self.observation_manager.compute(update_history=True) # return observations, rewards, resets and extras diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md b/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md new file mode 100644 index 00000000000..4023902c828 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md @@ -0,0 +1,130 @@ +# Manager-based → Warp-first migration plan (experimental) + +This doc captures the incremental migration plan to make the **manager-based workflow** (config-driven managers) become **Warp-first and CUDA-graph-friendly**, while keeping the same external behavior/API as the stable manager-based environments. + +Scope: start with **Cartpole (manager-based)** as the pilot task. + +## Goals + +- Preserve the manager-based authoring model (tasks defined via config + MDP terms). +- Keep Gym API behavior the same as the stable manager-based envs. +- Make the core step/reset loop **graph-capturable** (fixed launch topology, persistent buffers, mask-based subset operations). +- Avoid touching stable code by iterating inside `isaaclab_experimental` / `isaaclab_tasks_experimental`. 
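+
+To make the "graph-capturable" goal above concrete, the snippet below sketches the mask-based pattern in
+isolation (illustrative only; buffer names, sizes, and values are placeholders, not code from this change):
+kernels are launched with a fixed `dim`, buffers are allocated once and reused, and a persistent boolean mask
+selects which envs an operation touches, so the hot path needs no data-dependent indexing such as `nonzero()`.
+
+```python
+import warp as wp
+
+
+@wp.kernel
+def _reset_masked(values: wp.array(dtype=wp.float32), mask: wp.array(dtype=wp.bool)):
+    # one thread per env; only masked envs are modified, so the launch topology never changes
+    i = wp.tid()
+    if mask[i]:
+        values[i] = 0.0
+
+
+num_envs = 1024
+device = "cuda:0"
+# persistent buffers: allocated once, updated in-place (stable pointers are capture friendly)
+values = wp.zeros(num_envs, dtype=wp.float32, device=device)
+reset_mask = wp.zeros(num_envs, dtype=wp.bool, device=device)
+
+# every step: same kernel, same dim, same arguments; only the mask contents change
+wp.launch(_reset_masked, dim=num_envs, inputs=[values, reset_mask], device=device)
+```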
+ +## Current state (what exists already) + +- **Experimental env entry point**: `isaaclab_experimental.envs:ManagerBasedRLEnvWarp` +- **Experimental Cartpole config + mdp**: `isaaclab_tasks_experimental.manager_based.classic.cartpole.*` +- **First manager fork**: `isaaclab_experimental.managers.RewardManager` (Warp-backed buffers / kernels) +- **Action path (Cartpole-minimal)**: `isaaclab_experimental.managers.ActionManager` + `isaaclab_experimental.envs.mdp.actions` + - Warp-first manager boundary (`process_action` consumes `wp.array`; may temporarily accept `torch.Tensor` and convert via `wp.from_torch`) + - Mask-based reset API (preferred for capture): `reset(mask: wp.array | torch.Tensor | None)` + +## Phased migration (minimal + incremental) + +### Phase 0 — Baseline experimental entry points (no behavior change) + +What: +- Register new Gym IDs that point at experimental env entry points. +- Keep task configs stable unless explicitly copied for isolation. + +Why: +- Allows iteration without breaking stable tasks. + +Deliverables: +- `isaaclab_experimental.envs:ManagerBasedRLEnvWarp` +- `Isaac-…-v0` IDs under `isaaclab_tasks_experimental` + +### Phase 1 — Term-level Warp (keep Python managers, keep Torch-facing API) + +What: +- Introduce Warp implementations for *select* MDP terms (Cartpole rewards/obs/events) while keeping: + - manager orchestration in Python + - env returns (`obs`, `rew`, `terminated`, `truncated`) as **torch.Tensor** +- Use `wp.array` buffers internally and expose `torch` views via `wp.to_torch(...)` at boundaries. + +Why: +- Lets you validate Warp math + data plumbing without rewriting the entire manager framework. + +Typical changes: +- Add `out` buffers to term cfgs (or manager-owned persistent outputs). +- Convert term functions from “return torch” → “write into wp.array”. + +Cartpole focus: +- Rewards: pole angle term, alive/terminated terms, etc. +- Observations: joint pos/vel relative +- Events: reset by offset (mask-based subset) + +### Phase 2 — Manager-level Warp buffers (still Python-loop scheduling) + +What: +- Keep manager iteration in Python, but move all per-env buffers to Warp: + - reward accumulation buffers + - termination buffers + - (optionally) action/observation buffers +- Replace torch ops like `nonzero()`, `torch.mean(...)`, per-term tensor math with Warp kernels. + +Why: +- Removes Torch from the hot-path while keeping the overall structure intact. + +Deliverables (pilot): +- Warp-backed `RewardManager` (done/ongoing) +- Warp-backed `ActionManager` (Cartpole-minimal; mask-based reset; optional Torch shim) +- Next candidates: `TerminationManager`, `EventManager` (mask-based reset/interval), `ObservationManager` + +Notes (graph/capture): +- `wp.from_torch(...)` creates a lightweight Warp wrapper around the Torch tensor memory, but you still pay Python-side overhead per call. + For CUDA graph capture, prefer **persistent buffers** (stable pointers) and update them in-place, then pass the persistent `wp.array` + through the manager boundary. This is the same caveat noted in `DirectRLEnvWarp.step`. + +Notes (Torch → Warp porting): +- Torch implementations often “broadcast” per-joint constants into `(num_envs, action_dim)` tensors for convenience. + In Warp-first ports, prefer keeping these as **constant per-joint buffers** (e.g. `(action_dim,)` for `scale/offset`, + `(action_dim, 2)` for `clip`) and index by `j` inside kernels. This avoids redundant per-env storage and extra broadcast kernels, + while preserving behavior. 
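+
+As a rough illustration of the per-joint-constant note above (names, shapes, and values are placeholders, not
+code from this change), an action-processing kernel can read `scale`/`offset`/`clip` per joint index instead of
+broadcasting them into `(num_envs, action_dim)` buffers:
+
+```python
+import warp as wp
+
+
+@wp.kernel
+def _process_joint_actions(
+    raw: wp.array(dtype=wp.float32, ndim=2),  # (num_envs, action_dim)
+    scale: wp.array(dtype=wp.float32),  # (action_dim,), constant per joint
+    offset: wp.array(dtype=wp.float32),  # (action_dim,), constant per joint
+    clip: wp.array(dtype=wp.float32, ndim=2),  # (action_dim, 2) -> [min, max] per joint
+    processed: wp.array(dtype=wp.float32, ndim=2),  # (num_envs, action_dim)
+):
+    i, j = wp.tid()
+    value = raw[i, j] * scale[j] + offset[j]
+    processed[i, j] = wp.clamp(value, clip[j, 0], clip[j, 1])
+
+
+num_envs, action_dim = 1024, 1
+device = "cuda:0"
+raw_actions = wp.zeros((num_envs, action_dim), dtype=wp.float32, device=device)
+processed_actions = wp.zeros((num_envs, action_dim), dtype=wp.float32, device=device)
+# example per-joint constants (values are arbitrary)
+scale = wp.full(action_dim, 100.0, dtype=wp.float32, device=device)
+offset = wp.zeros(action_dim, dtype=wp.float32, device=device)
+clip = wp.array([[-400.0, 400.0]] * action_dim, dtype=wp.float32, device=device)
+
+# one 2D launch covers every (env, joint) pair; per-joint constants are indexed by j inside the kernel
+wp.launch(
+    _process_joint_actions,
+    dim=(num_envs, action_dim),
+    inputs=[raw_actions, scale, offset, clip, processed_actions],
+    device=device,
+)
+```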
+ +### Phase 3 — Dependency surfacing and hybrid handling + +What: +- Identify and isolate subsystems that still create Torch buffers internally (common examples: contact sensors, some recorders). +- For each dependency: + - either keep as Torch “edge” (temporarily), or + - create Warp-first equivalents / alternate codepaths + +Why: +- Some dependencies are not purely “MDP math” and need dedicated rewrites for graphability. + +### Phase 4 — Graph-friendly orchestrator rewrite (fixed topology + masks) + +What: +- Replace the dynamic parts of the env `step()`/`reset()` control flow: + - eliminate dynamic indexing patterns (e.g., `nonzero()` → env-id lists) + - use **boolean masks** (`wp.array(dtype=wp.bool)`) and kernels that apply to subsets + - ensure persistent buffers are allocated once and reused + - ensure launch order is stable and capture-ready + +Why: +- CUDA graph capture requires stable execution topology. + +Key design rules: +- **No per-step Python branching on data-dependent indices** (or keep it outside capture). +- Prefer `mask`-based APIs where possible (e.g., scene reset supports mask). +- Maintain one-time allocations; no shape changes. + +### Phase 5 — Cleanup + consolidation + +What: +- Remove transitional Torch shims and duplication where no longer needed. +- Optionally add a stable public entry point once the experimental path is validated. + +## Practical “copy vs reuse” policy + +- **Copy** into experimental when you expect semantic changes (Cartpole config/mdp, selected managers). +- **Reuse** stable implementations for everything else until it becomes a blocker. +- Prefer one fork at a time (e.g., start with `RewardManager`, then termination, then events). + +## Suggested next steps (Cartpole) + +- Keep Cartpole task config isolated under `isaaclab_tasks_experimental`. +- Continue stabilizing the experimental `RewardManager` interface (decide: term returns vs term writes). +- Add the next minimal manager fork: `TerminationManager` using Warp buffers (still return torch views). diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py index a1a0173b7e5..d5d9f01383a 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -20,11 +20,14 @@ from collections.abc import Sequence from typing import Any +import warp as wp +from isaaclab_experimental.managers import ActionManager + from isaaclab.envs.common import VecEnvObs from isaaclab.envs.manager_based_env_cfg import ManagerBasedEnvCfg from isaaclab.envs.ui import ViewportCameraController from isaaclab.envs.utils.io_descriptors import export_articulations_data, export_scene_data -from isaaclab.managers import ActionManager, EventManager, ObservationManager, RecorderManager +from isaaclab.managers import EventManager, ObservationManager, RecorderManager from isaaclab.scene import InteractiveScene from isaaclab.sim import SimulationContext from isaaclab.sim.utils import use_stage @@ -433,7 +436,13 @@ def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]: A tuple containing the observations and extras. 
""" # process actions - self.action_manager.process_action(action.to(self.device)) + action_device = action.to(self.device) + if action_device.dtype != torch.float32: + action_device = action_device.float() + if not action_device.is_contiguous(): + action_device = action_device.contiguous() + action_wp = wp.from_torch(action_device, dtype=wp.float32) + self.action_manager.process_action(action_wp) self.recorder_manager.record_pre_step() @@ -535,7 +544,11 @@ def _reset_idx(self, env_ids: Sequence[int]): info = self.observation_manager.reset(env_ids) self.extras["log"].update(info) # -- action manager - info = self.action_manager.reset(env_ids) + # TODO(jichuanh): mask should be natively provided + mask = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + mask[env_ids] = True + mask_wp = wp.from_torch(mask, dtype=wp.bool) + info = self.action_manager.reset(mask=mask_wp) self.extras["log"].update(info) # -- event manager info = self.event_manager.reset(env_ids) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 89579d5422c..82db14efa77 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -20,6 +20,7 @@ from collections.abc import Sequence from typing import Any, ClassVar +import warp as wp from isaaclab_experimental.managers import RewardManager from isaaclab.envs.common import VecEnvStepReturn @@ -180,14 +181,14 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: A tuple containing the observations, rewards, resets (terminated and truncated) and extras. """ # process actions - action_device = action.to(self.device) + action_wp = wp.from_torch(action) with Timer( name="action_manager.process_action", msg="ActionManager.process_action took:", enable=True, format="us", ): - self.action_manager.process_action(action_device) + self.action_manager.process_action(action_wp) self.recorder_manager.record_pre_step() @@ -427,7 +428,11 @@ def _reset_idx(self, env_ids: Sequence[int]): info = self.observation_manager.reset(env_ids) self.extras["log"].update(info) # -- action manager - info = self.action_manager.reset(env_ids) + # TODO(jichuanh): mask should be natively provided + mask = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + mask[env_ids] = True + mask_wp = wp.from_torch(mask, dtype=wp.bool) + info = self.action_manager.reset(mask=mask_wp) self.extras["log"].update(info) # -- rewards manager info = self.reward_manager.reset(env_ids) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py index 74858972a57..3c00e157b6d 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py @@ -9,10 +9,8 @@ functions with Warp-first implementations from :mod:`isaaclab_experimental.envs.mdp.rewards`. """ -# Forward stable MDP terms (actions/observations/terminations/etc.) but *exclude* rewards. -# Rewards are provided by this experimental package to keep a Warp-compatible signature -# (`func(env, out, **params) -> None`) for the experimental RewardManager. -from isaaclab.envs.mdp.actions import * # noqa: F401, F403 +# Forward stable MDP terms (commands/observations/terminations/etc.) 
but *exclude* rewards and actions. +# Rewards and actions are provided by this experimental package to keep Warp-first execution. from isaaclab.envs.mdp.commands import * # noqa: F401, F403 from isaaclab.envs.mdp.curriculums import * # noqa: F401, F403 from isaaclab.envs.mdp.events import * # noqa: F401, F403 @@ -21,4 +19,5 @@ from isaaclab.envs.mdp.terminations import * # noqa: F401, F403 # Override reward terms with experimental implementations. +from .actions import * # noqa: F401, F403 from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py new file mode 100644 index 00000000000..283805a279f --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental action terms (minimal). + +Only the action configs/terms currently required by the experimental manager-based Cartpole task +are provided here. +""" + +from .actions_cfg import * # noqa: F401, F403 +from .joint_actions import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py new file mode 100644 index 00000000000..fa75f69d045 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Action term configuration (experimental, minimal). + +This module mirrors the stable :mod:`isaaclab.envs.mdp.actions.actions_cfg` but only keeps what +the experimental Cartpole task needs. +""" + +from dataclasses import MISSING + +from isaaclab_experimental.managers.action_manager import ActionTerm, ActionTermCfg + +from isaaclab.utils import configclass + +from . import joint_actions + + +@configclass +class JointActionCfg(ActionTermCfg): + """Configuration for the base joint action term.""" + + joint_names: list[str] = MISSING + """List of joint names or regex expressions that the action will be mapped to.""" + + scale: float | dict[str, float] = 1.0 + """Scale factor for the action (float or dict of regex expressions). Defaults to 1.0.""" + + offset: float | dict[str, float] = 0.0 + """Offset factor for the action (float or dict of regex expressions). Defaults to 0.0.""" + + preserve_order: bool = False + """Whether to preserve the order of the joint names in the action output. 
Defaults to False.""" + + +@configclass +class JointEffortActionCfg(JointActionCfg): + """Configuration for the joint effort action term.""" + + class_type: type[ActionTerm] = joint_actions.JointEffortAction diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py new file mode 100644 index 00000000000..e4e8a039eff --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import warp as wp +from isaaclab_experimental.managers.action_manager import ActionTerm + +import isaaclab.utils.string as string_utils +from isaaclab.assets.articulation import Articulation + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + from isaaclab.envs.utils.io_descriptors import GenericActionIODescriptor + + from . import actions_cfg + +# import logger +logger = logging.getLogger(__name__) + + +@wp.kernel +def _process_joint_actions_kernel( + # input + actions: wp.array(dtype=wp.float32, ndim=2), + action_offset: int, + # params + scale: wp.array(dtype=wp.float32), + offset: wp.array(dtype=wp.float32), + clip: wp.array(dtype=wp.float32, ndim=2), + # output + raw_out: wp.array(dtype=wp.float32, ndim=2), + processed_out: wp.array(dtype=wp.float32, ndim=2), +): + env_id, j = wp.tid() + col = action_offset + j + + a = actions[env_id, col] + raw_out[env_id, j] = a + + x = a * scale[j] + offset[j] + low = clip[j, 0] + high = clip[j, 1] + if x < low: + x = low + if x > high: + x = high + processed_out[env_id, j] = x + + +@wp.kernel +def _set_clip_1d_to_2d( + clip_low: wp.array(dtype=wp.float32), + clip_high: wp.array(dtype=wp.float32), + out: wp.array(dtype=wp.float32, ndim=2), +): + j = wp.tid() + out[j, 0] = clip_low[j] + out[j, 1] = clip_high[j] + + +@wp.kernel +def _zero_masked_2d(mask: wp.array(dtype=wp.bool), values: wp.array(dtype=wp.float32, ndim=2)): + env_id, j = wp.tid() + if mask[env_id]: + values[env_id, j] = 0.0 + + +class JointAction(ActionTerm): + r"""Base class for joint actions. + + This action term performs pre-processing of the raw actions using affine transformations (scale and offset). + These transformations can be configured to be applied to a subset of the articulation's joints. + + Mathematically, the action term is defined as: + + .. math:: + + \text{action} = \text{offset} + \text{scaling} \times \text{input action} + + where :math:`\text{action}` is the action that is sent to the articulation's actuated joints, :math:`\text{offset}` + is the offset applied to the input action, :math:`\text{scaling}` is the scaling applied to the input + action, and :math:`\text{input action}` is the input action from the user. + + Based on above, this kind of action transformation ensures that the input and output actions are in the same + units and dimensions. The child classes of this action term can then map the output action to a specific + desired command of the articulation's joints (e.g. position, velocity, etc.). 
+ """ + + cfg: actions_cfg.JointActionCfg + """The configuration of the action term.""" + _asset: Articulation + """The articulation asset on which the action term is applied.""" + _scale: wp.array + """The scaling factor applied to the input action.""" + _offset: wp.array + """The offset applied to the input action.""" + _clip: wp.array + """The clip applied to the input action.""" + + def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> None: + # initialize the action term + super().__init__(cfg, env) + + # resolve the joints over which the action term is applied + _, self._joint_names, self._joint_ids = self._asset.find_joints( + self.cfg.joint_names, preserve_order=self.cfg.preserve_order + ) + self._num_joints = len(self._joint_ids) + # log the resolved joint names for debugging + logger.info( + f"Resolved joint names for the action term {self.__class__.__name__}:" + f" {self._joint_names} [{self._joint_ids}]" + ) + + # Avoid indexing across all joints for efficiency + if self._num_joints == self._asset.num_joints and not self.cfg.preserve_order: + self._joint_ids = slice(None) + + # create tensors for raw and processed actions (Warp) + self._raw_actions = wp.zeros((self.num_envs, self.action_dim), dtype=wp.float32, device=self.device) + self._processed_actions = wp.zeros_like(self.raw_actions) + + # parse scale + if isinstance(cfg.scale, (float, int)): + self._scale = wp.array([float(cfg.scale)] * self.action_dim, dtype=wp.float32, device=self.device) + elif isinstance(cfg.scale, dict): + scale_per_joint = [1.0] * self.action_dim + # resolve the dictionary config + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.scale, self._joint_names) + for idx, value in zip(index_list, value_list): + scale_per_joint[idx] = float(value) + self._scale = wp.array(scale_per_joint, dtype=wp.float32, device=self.device) + else: + raise ValueError(f"Unsupported scale type: {type(cfg.scale)}. Supported types are float and dict.") + + # parse offset + if isinstance(cfg.offset, (float, int)): + self._offset = wp.array([float(cfg.offset)], dtype=wp.float32, device=self.device) + elif isinstance(cfg.offset, dict): + offset_per_joint = [0.0] * self.action_dim + # resolve the dictionary config + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.offset, self._joint_names) + for idx, value in zip(index_list, value_list): + offset_per_joint[idx] = float(value) + self._offset = wp.array(offset_per_joint, dtype=wp.float32, device=self.device) + else: + raise ValueError(f"Unsupported offset type: {type(cfg.offset)}. Supported types are float and dict.") + + # parse clip + clip_low = [-float("inf")] * self.action_dim + clip_high = [float("inf")] * self.action_dim + if self.cfg.clip is not None: + if isinstance(cfg.clip, dict): + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names) + for idx, value in zip(index_list, value_list): + clip_low[idx] = float(value[0]) + clip_high[idx] = float(value[1]) + else: + raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. 
Supported types are dict.") + + clip_low_vec = wp.array(clip_low, dtype=wp.float32, device=self.device) + clip_high_vec = wp.array(clip_high, dtype=wp.float32, device=self.device) + self._clip = wp.zeros((self.action_dim, 2), dtype=wp.float32, device=self.device) + wp.launch( + kernel=_set_clip_1d_to_2d, + dim=self.action_dim, + inputs=[clip_low_vec, clip_high_vec, self._clip], + device=self.device, + ) + + """ + Properties. + """ + + @property + def action_dim(self) -> int: + return self._num_joints + + @property + def raw_actions(self) -> wp.array: + return self._raw_actions + + @property + def processed_actions(self) -> wp.array: + return self._processed_actions + + @property + def IO_descriptor(self) -> GenericActionIODescriptor: + """The IO descriptor of the action term. + + This descriptor is used to describe the action term of the joint action. + It adds the following information to the base descriptor: + - joint_names: The names of the joints. + - scale: The scale of the action term. + - offset: The offset of the action term. + - clip: The clip of the action term. + + Returns: + The IO descriptor of the action term. + """ + super().IO_descriptor + self._IO_descriptor.shape = (self.action_dim,) + self._IO_descriptor.dtype = str(self.raw_actions.dtype) + self._IO_descriptor.action_type = "JointAction" + self._IO_descriptor.joint_names = self._joint_names + self._IO_descriptor.scale = self._scale + # This seems to be always [4xNum_joints] IDK why. Need to check. + if isinstance(self._offset, wp.array): + self._IO_descriptor.offset = self._offset.numpy().tolist() + else: + self._IO_descriptor.offset = None + # FIXME: This is not correct. Add list support. + if self.cfg.clip is not None: + if isinstance(self._clip, wp.array): + self._IO_descriptor.clip = self._clip.numpy().tolist() + else: + self._IO_descriptor.clip = None + else: + self._IO_descriptor.clip = None + return self._IO_descriptor + + """ + Operations. 
+ """ + + def process_actions(self, actions: wp.array, action_offset: int = 0): + wp.launch( + kernel=_process_joint_actions_kernel, + dim=(self.num_envs, self.action_dim), + inputs=[ + actions, + int(action_offset), + self._scale, + self._offset, + self._clip, + self._raw_actions, + self._processed_actions, + ], + device=self.device, + ) + + def reset(self, mask: wp.array | None = None) -> None: + """Resets the action term (mask-based).""" + if mask is None: + self._raw_actions.fill_(0.0) + return + wp.launch( + kernel=_zero_masked_2d, + dim=(self.num_envs, self.action_dim), + inputs=[mask, self._raw_actions], + device=self.device, + ) + + +class JointEffortAction(JointAction): + """Joint action term that applies the processed actions to the articulation's joints as effort commands.""" + + cfg: actions_cfg.JointEffortActionCfg + """The configuration of the action term.""" + + def __init__(self, cfg: actions_cfg.JointEffortActionCfg, env: ManagerBasedEnv): + super().__init__(cfg, env) + + def apply_actions(self): + # set joint effort targets + self._asset.set_joint_effort_target(self.processed_actions, joint_ids=self._joint_ids) diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py index 1951f9cc5ae..0a89bced77c 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py @@ -12,6 +12,7 @@ from isaaclab.managers import * # noqa: F401,F403 # Override the stable implementation with the experimental fork. +from .action_manager import ActionManager # noqa: F401 from .manager_term_cfg import RewardTermCfg # noqa: F401 from .reward_manager import RewardManager # noqa: F401 from .scene_entity_cfg import SceneEntityCfg # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py new file mode 100644 index 00000000000..082a00fa0ef --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py @@ -0,0 +1,485 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Action manager for processing actions sent to the environment.""" + +from __future__ import annotations + +import inspect +import re +import weakref +from abc import abstractmethod +from collections.abc import Sequence +from prettytable import PrettyTable +from typing import TYPE_CHECKING, Any + +import warp as wp + +from isaaclab.assets import AssetBase +from isaaclab.envs.utils.io_descriptors import GenericActionIODescriptor + +from .manager_base import ManagerBase, ManagerTermBase +from .manager_term_cfg import ActionTermCfg + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + + +@wp.kernel +def _zero_masked_2d(mask: wp.array(dtype=wp.bool), data: wp.array(dtype=wp.float32, ndim=2)): + """Zero rows of a 2D buffer where ``mask`` is True. + + Launched with dim = (num_envs, data.shape[1]). + """ + + env_id, j = wp.tid() + if mask[env_id]: + data[env_id, j] = 0.0 + + +class ActionTerm(ManagerTermBase): + """Base class for action terms. + + The action term is responsible for processing the raw actions sent to the environment + and applying them to the asset managed by the term. 
The action term is comprised of two + operations: + + * Processing of actions: This operation is performed once per **environment step** and + is responsible for pre-processing the raw actions sent to the environment. + * Applying actions: This operation is performed once per **simulation step** and is + responsible for applying the processed actions to the asset managed by the term. + """ + + def __init__(self, cfg: ActionTermCfg, env: ManagerBasedEnv): + """Initialize the action term. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + # call the base class constructor + super().__init__(cfg, env) + # parse config to obtain asset to which the term is applied + self._asset: AssetBase = self._env.scene[self.cfg.asset_name] + self._IO_descriptor = GenericActionIODescriptor() + self._export_IO_descriptor = True + + # add handle for debug visualization (this is set to a valid handle inside set_debug_vis) + self._debug_vis_handle = None + # set initial state of debug visualization + self.set_debug_vis(self.cfg.debug_vis) + + def __del__(self): + """Unsubscribe from the callbacks.""" + if self._debug_vis_handle: + self._debug_vis_handle.unsubscribe() + self._debug_vis_handle = None + + """ + Properties. + """ + + @property + @abstractmethod + def action_dim(self) -> int: + """Dimension of the action term.""" + raise NotImplementedError + + @property + @abstractmethod + def raw_actions(self) -> wp.array: + """The input/raw actions sent to the term.""" + raise NotImplementedError + + @property + @abstractmethod + def processed_actions(self) -> wp.array: + """The actions computed by the term after applying any processing.""" + raise NotImplementedError + + @property + def has_debug_vis_implementation(self) -> bool: + """Whether the action term has a debug visualization implemented.""" + # check if function raises NotImplementedError + source_code = inspect.getsource(self._set_debug_vis_impl) + return "NotImplementedError" not in source_code + + @property + def IO_descriptor(self) -> GenericActionIODescriptor: + """The IO descriptor for the action term.""" + self._IO_descriptor.name = re.sub(r"([a-z])([A-Z])", r"\1_\2", self.__class__.__name__).lower() + self._IO_descriptor.full_path = f"{self.__class__.__module__}.{self.__class__.__name__}" + self._IO_descriptor.description = " ".join(self.__class__.__doc__.split()) + self._IO_descriptor.export = self.export_IO_descriptor + return self._IO_descriptor + + @property + def export_IO_descriptor(self) -> bool: + """Whether to export the IO descriptor for the action term.""" + return self._export_IO_descriptor + + """ + Operations. + """ + + def set_debug_vis(self, debug_vis: bool) -> bool: + """Sets whether to visualize the action term data. + Args: + debug_vis: Whether to visualize the action term data. + Returns: + Whether the debug visualization was successfully set. False if the action term does + not support debug visualization. 
+ """ + # check if debug visualization is supported + if not self.has_debug_vis_implementation: + return False + + import omni.kit.app + + # toggle debug visualization objects + self._set_debug_vis_impl(debug_vis) + # toggle debug visualization handles + if debug_vis: + # create a subscriber for the post update event if it doesn't exist + if self._debug_vis_handle is None: + app_interface = omni.kit.app.get_app_interface() + self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop( + lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event) + ) + else: + # remove the subscriber if it exists + if self._debug_vis_handle is not None: + self._debug_vis_handle.unsubscribe() + self._debug_vis_handle = None + # return success + return True + + @abstractmethod + def process_actions(self, actions: wp.array, action_offset: int = 0): + """Processes the actions sent to the environment. + + Note: + This function is called once per environment step by the manager. + + Args: + actions: The full action buffer of shape (num_envs, total_action_dim). + action_offset: Column offset into the action buffer for this term. + """ + raise NotImplementedError + + @abstractmethod + def apply_actions(self): + """Applies the actions to the asset managed by the term. + + Note: + This is called at every simulation step by the manager. + """ + raise NotImplementedError + + def _set_debug_vis_impl(self, debug_vis: bool): + """Set debug visualization into visualization objects. + This function is responsible for creating the visualization objects if they don't exist + and input ``debug_vis`` is True. If the visualization objects exist, the function should + set their visibility into the stage. + """ + raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.") + + def _debug_vis_callback(self, event): + """Callback for debug visualization. + This function calls the visualization objects and sets the data to visualize into them. + """ + raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.") + + +class ActionManager(ManagerBase): + """Manager for processing and applying actions for a given world. + + The action manager handles the interpretation and application of user-defined + actions on a given world. It is comprised of different action terms that decide + the dimension of the expected actions. + + The action manager performs operations at two stages: + + * processing of actions: It splits the input actions to each term and performs any + pre-processing needed. This should be called once at every environment step. + * apply actions: This operation typically sets the processed actions into the assets in the + scene (such as robots). It should be called before every simulation step. + """ + + def __init__(self, cfg: object, env: ManagerBasedEnv): + """Initialize the action manager. + + Args: + cfg: The configuration object or dictionary (``dict[str, ActionTermCfg]``). + env: The environment instance. + + Raises: + ValueError: If the configuration is None. + """ + # check if config is None + if cfg is None: + raise ValueError("Action manager configuration is None. 
Please provide a valid configuration.") + + # call the base class constructor (this prepares the terms) + super().__init__(cfg, env) + # create buffers to store actions (Warp) + self._action = wp.zeros((self.num_envs, self.total_action_dim), dtype=wp.float32, device=self.device) + self._prev_action = wp.zeros((self.num_envs, self.total_action_dim), dtype=wp.float32, device=self.device) + + # torch views + self._action_torch = wp.to_torch(self._action) + self._prev_action_torch = wp.to_torch(self._prev_action) + + # check if any term has debug visualization implemented + self.cfg.debug_vis = False + for term in self._terms.values(): + self.cfg.debug_vis |= term.cfg.debug_vis + + def __str__(self) -> str: + """Returns: A string representation for action manager.""" + msg = f" contains {len(self._term_names)} active terms.\n" + + # create table for term information + table = PrettyTable() + table.title = f"Active Action Terms (shape: {self.total_action_dim})" + table.field_names = ["Index", "Name", "Dimension"] + # set alignment of table columns + table.align["Name"] = "l" + table.align["Dimension"] = "r" + # add info on each term + for index, (name, term) in enumerate(self._terms.items()): + table.add_row([index, name, term.action_dim]) + # convert table to string + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. + """ + + @property + def total_action_dim(self) -> int: + """Total dimension of actions.""" + return sum(self.action_term_dim) + + @property + def active_terms(self) -> list[str]: + """Name of active action terms.""" + return self._term_names + + @property + def action_term_dim(self) -> list[int]: + """Shape of each action term.""" + return [term.action_dim for term in self._terms.values()] + + @property + def action(self) -> wp.array: + """The actions sent to the environment. Shape is (num_envs, total_action_dim).""" + return self._action + + @property + def prev_action(self) -> wp.array: + """The previous actions sent to the environment. Shape is (num_envs, total_action_dim).""" + return self._prev_action + + @property + def has_debug_vis_implementation(self) -> bool: + """Whether the command terms have debug visualization implemented.""" + # check if function raises NotImplementedError + has_debug_vis = False + for term in self._terms.values(): + has_debug_vis |= term.has_debug_vis_implementation + return has_debug_vis + + @property + def get_IO_descriptors(self) -> list[dict[str, Any]]: + """Get the IO descriptors for the action manager. + + Returns: + A dictionary with keys as the term names and values as the IO descriptors. + """ + + data = [] + + for term_name, term in self._terms.items(): + try: + data.append(term.IO_descriptor.__dict__.copy()) + except Exception as e: + print(f"Error getting IO descriptor for term '{term_name}': {e}") + + formatted_data = [] + for item in data: + name = item.pop("name") + formatted_item = {"name": name, "extras": item.pop("extras")} + print(item["export"]) + if not item.pop("export"): + continue + for k, v in item.items(): + # Check if v is a tuple and convert to list + if isinstance(v, tuple): + v = list(v) + if k in ["description", "units"]: + formatted_item["extras"][k] = v + else: + formatted_item[k] = v + formatted_data.append(formatted_item) + + return formatted_data + + """ + Operations. + """ + + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. 
+ + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + terms = [] + idx = 0 + # Copy to host for debug/inspection (not on hot path). + for name, term in self._terms.items(): + term_actions = self._action_torch[env_idx, idx : idx + term.action_dim] + terms.append((name, term_actions.tolist())) + idx += term.action_dim + return terms + + def set_debug_vis(self, debug_vis: bool): + """Sets whether to visualize the action data. + Args: + debug_vis: Whether to visualize the action data. + Returns: + Whether the debug visualization was successfully set. False if the action + does not support debug visualization. + """ + for term in self._terms.values(): + term.set_debug_vis(debug_vis) + + def reset(self, mask: wp.array | None = None) -> dict[str, Any]: + """Resets the action history. + + Args: + mask: Boolean mask of shape (num_envs,) indicating which envs to reset. + Defaults to None, in which case all environments are considered. + + Returns: + An empty dictionary. + """ + # reset the action history + if mask is None: + self._prev_action.fill_(0.0) + self._action.fill_(0.0) + else: + wp.launch( + kernel=_zero_masked_2d, + dim=(self.num_envs, self.total_action_dim), + inputs=[mask, self._prev_action], + device=self.device, + ) + wp.launch( + kernel=_zero_masked_2d, + dim=(self.num_envs, self.total_action_dim), + inputs=[mask, self._action], + device=self.device, + ) + + # reset all action terms + for term in self._terms.values(): + term.reset(mask=mask) + # nothing to log here + return {} + + def process_action(self, action: wp.array): + """Processes the actions sent to the environment. + + Note: + This function should be called once per environment step. + + Args: + action: The actions to process. Shape is (num_envs, total_action_dim). + """ + # check if action dimension is valid + if self.total_action_dim != action.shape[1]: + raise ValueError(f"Invalid action shape, expected: {self.total_action_dim}, received: {action.shape[1]}.") + + # store the input actions + wp.copy(self._prev_action, self._action) + wp.copy(self._action, action) + + # split the actions and apply to each term + idx = 0 + for term in self._terms.values(): + term.process_actions(self._action, idx) + idx += term.action_dim + + def apply_action(self) -> None: + """Applies the actions to the environment/simulation. + + Note: + This should be called at every simulation step. + """ + for term in self._terms.values(): + term.apply_actions() + + def get_term(self, name: str) -> ActionTerm: + """Returns the action term with the specified name. + + Args: + name: The name of the action term. + + Returns: + The action term with the specified name. + """ + return self._terms[name] + + def serialize(self) -> dict: + """Serialize the action manager configuration. + + Returns: + A dictionary of serialized action term configurations. + """ + return {term_name: term.serialize() for term_name, term in self._terms.items()} + + """ + Helper functions. 
+ """ + + def _prepare_terms(self): + # create buffers to parse and store terms + self._term_names: list[str] = list() + self._terms: dict[str, ActionTerm] = dict() + + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + # parse action terms from the config + for term_name, term_cfg in cfg_items: + # check if term config is None + if term_cfg is None: + continue + # check valid type + if not isinstance(term_cfg, ActionTermCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ActionTermCfg." + f" Received: '{type(term_cfg)}'." + ) + # create the action term + term = term_cfg.class_type(term_cfg, self._env) + # sanity check if term is valid type + if not isinstance(term, ActionTerm): + raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionType.") + # add term name and parameters + self._term_names.append(term_name) + self._terms[term_name] = term diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py new file mode 100644 index 00000000000..0fddc1c016a --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py @@ -0,0 +1,435 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Base classes for managers (experimental). + +This file is a local copy of :mod:`isaaclab.managers.manager_base` placed under +``isaaclab_experimental`` so it can evolve independently for Warp-first / graph-friendly +pipelines. + +Key differences from the stable version: +- :meth:`ManagerTermBase.reset` is **mask-based** (preferred for capture-friendly subset operations). +""" + +from __future__ import annotations + +import contextlib +import copy +import inspect +import logging +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import warp as wp + +import isaaclab.utils.string as string_utils +from isaaclab.utils import class_to_dict, string_to_callable + +from .manager_term_cfg import ManagerTermBaseCfg +from .scene_entity_cfg import SceneEntityCfg + +# import omni.timeline + + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + +# import logger +logger = logging.getLogger(__name__) + + +class ManagerTermBase(ABC): + """Base class for manager terms. + + Manager term implementations can be functions or classes. If the term is a class, it should + inherit from this base class and implement the required methods. + + Each manager is implemented as a class that inherits from the :class:`ManagerBase` class. Each manager + class should also have a corresponding configuration class that defines the configuration terms for the + manager. Each term should the :class:`ManagerTermBaseCfg` class or its subclass. + + Example pseudo-code for creating a manager: + + .. code-block:: python + + from isaaclab.utils import configclass + from isaaclab.utils.mdp import ManagerBase, ManagerTermBaseCfg + + @configclass + class MyManagerCfg: + + my_term_1: ManagerTermBaseCfg = ManagerTermBaseCfg(...) + my_term_2: ManagerTermBaseCfg = ManagerTermBaseCfg(...) + my_term_3: ManagerTermBaseCfg = ManagerTermBaseCfg(...) 
+ + # define manager instance + my_manager = ManagerBase(cfg=ManagerCfg(), env=env) + + """ + + def __init__(self, cfg: ManagerTermBaseCfg, env: ManagerBasedEnv): + """Initialize the manager term. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + # store the inputs + self.cfg = cfg + self._env = env + + """ + Properties. + """ + + @property + def num_envs(self) -> int: + """Number of environments.""" + return self._env.num_envs + + @property + def device(self) -> str: + """Device on which to perform computations.""" + return self._env.device + + @property + def __name__(self) -> str: + """Return the name of the class or subclass.""" + return self.__class__.__name__ + + """ + Operations. + """ + + def reset(self, mask: wp.array | None = None) -> None: + """Resets the manager term (mask-based). + + Args: + mask: Boolean mask of shape (num_envs,) indicating which envs to reset. + If None, all envs are considered. + """ + pass + + def serialize(self) -> dict: + """General serialization call. Includes the configuration dict.""" + return {"cfg": class_to_dict(self.cfg)} + + def __call__(self, *args) -> Any: + """Returns the value of the term required by the manager. + + In case of a class implementation, this function is called by the manager + to get the value of the term. The arguments passed to this function are + the ones specified in the term configuration (see :attr:`ManagerTermBaseCfg.params`). + + .. attention:: + To be consistent with memory-less implementation of terms with functions, it is + recommended to ensure that the returned mutable quantities are cloned before + returning them. For instance, if the term returns a tensor, it is recommended + to ensure that the returned tensor is a clone of the original tensor. This prevents + the manager from storing references to the tensors and altering the original tensors. + + Args: + *args: Variable length argument list. + + Returns: + The value of the term. + """ + raise NotImplementedError("The method '__call__' should be implemented by the subclass.") + + +class ManagerBase(ABC): + """Base class for all managers.""" + + def __init__(self, cfg: object, env: ManagerBasedEnv): + """Initialize the manager. + + This function is responsible for parsing the configuration object and creating the terms. + + If the simulation is not playing, the scene entities are not resolved immediately. + Instead, the resolution is deferred until the simulation starts. This is done to ensure + that the scene entities are resolved even if the manager is created after the simulation + has already started. + + Args: + cfg: The configuration object. If None, the manager is initialized without any terms. + env: The environment instance. + """ + # store the inputs + self.cfg = copy.deepcopy(cfg) + self._env = env + + # flag for whether the scene entities have been resolved + # if sim is playing, we resolve the scene entities directly while preparing the terms + self._is_scene_entities_resolved = self._env.sim.is_playing() + + # if the simulation is not playing, we use callbacks to trigger the resolution of the scene + # entities configuration. this is needed for cases where the manager is created after the simulation + # but before the simulation is playing. + # FIXME: Once Isaac Sim supports storing this information as USD schema, we can remove this + # callback and resolve the scene entities directly inside `_prepare_terms`. 
+ # if not self._env.sim.is_playing(): + # # note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor + # # is called + # # The order is set to 20 to allow asset/sensor initialization to complete before the scene entities + # # are resolved. Those have the order 10. + # timeline_event_stream = omni.timeline.get_timeline_interface().get_timeline_event_stream() + # self._resolve_terms_handle = timeline_event_stream.create_subscription_to_pop_by_type( + # int(omni.timeline.TimelineEventType.PLAY), + # lambda event, obj=weakref.proxy(self): obj._resolve_terms_callback(event), + # order=20, + # ) + # else: + # self._resolve_terms_handle = None + self._resolve_terms_handle = None + + # parse config to create terms information + if self.cfg: + self._prepare_terms() + + def __del__(self): + """Delete the manager.""" + # Suppress errors during Python shutdown + # Note: contextlib may be None during interpreter shutdown + if contextlib is not None: + with contextlib.suppress(ImportError, AttributeError, TypeError): + if getattr(self, "_resolve_terms_handle", None): + self._resolve_terms_handle.unsubscribe() + self._resolve_terms_handle = None + + """ + Properties. + """ + + @property + def num_envs(self) -> int: + """Number of environments.""" + return self._env.num_envs + + @property + def device(self) -> str: + """Device on which to perform computations.""" + return self._env.device + + @property + @abstractmethod + def active_terms(self) -> list[str] | dict[str, list[str]]: + """Name of active terms.""" + raise NotImplementedError + + """ + Operations. + """ + + def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]: + """Resets the manager and returns logging information for the current time-step. + + Args: + env_ids: The environment ids for which to log data. + Defaults None, which logs data for all environments. + + Returns: + Dictionary containing the logging information. + """ + return {} + + def find_terms(self, name_keys: str | Sequence[str]) -> list[str]: + """Find terms in the manager based on the names. + + This function searches the manager for terms based on the names. The names can be + specified as regular expressions or a list of regular expressions. The search is + performed on the active terms in the manager. + + Please check the :meth:`~isaaclab.utils.string_utils.resolve_matching_names` function for more + information on the name matching. + + Args: + name_keys: A regular expression or a list of regular expressions to match the term names. + + Returns: + A list of term names that match the input keys. + """ + # resolve search keys + if isinstance(self.active_terms, dict): + list_of_strings = [] + for names in self.active_terms.values(): + list_of_strings.extend(names) + else: + list_of_strings = self.active_terms + + # return the matching names + return string_utils.resolve_matching_names(name_keys, list_of_strings)[1] + + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Returns: + The active terms. + """ + raise NotImplementedError + + """ + Implementation specific. + """ + + @abstractmethod + def _prepare_terms(self): + """Prepare terms information from the configuration object.""" + raise NotImplementedError + + """ + Internal callbacks. 
+ """ + + def _resolve_terms_callback(self, event): + """Resolve configurations of terms once the simulation starts. + + Please check the :meth:`_process_term_cfg_at_play` method for more information. + """ + # check if scene entities have been resolved + if self._is_scene_entities_resolved: + return + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + + # iterate over all the terms + for term_name, term_cfg in cfg_items: + # check for non config + if term_cfg is None: + continue + # process attributes at runtime + # these properties are only resolvable once the simulation starts playing + self._process_term_cfg_at_play(term_name, term_cfg) + + # set the flag + self._is_scene_entities_resolved = True + + """ + Internal functions. + """ + + def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, min_argc: int = 1): + """Resolve common attributes of the term configuration. + + Usually, called by the :meth:`_prepare_terms` method to resolve common attributes of the term + configuration. These include: + + * Resolving the term function and checking if it is callable. + * Checking if the term function's arguments are matched by the parameters. + * Resolving special attributes of the term configuration like ``asset_cfg``, ``sensor_cfg``, etc. + * Initializing the term if it is a class. + + The last two steps are only possible once the simulation starts playing. + + By default, all term functions are expected to have at least one argument, which is the + environment object. Some other managers may expect functions to take more arguments, for + instance, the environment indices as the second argument. In such cases, the + ``min_argc`` argument can be used to specify the minimum number of arguments + required by the term function to be called correctly by the manager. + + Args: + term_name: The name of the term. + term_cfg: The term configuration. + min_argc: The minimum number of arguments required by the term function to be called correctly + by the manager. + + Raises: + TypeError: If the term configuration is not of type :class:`ManagerTermBaseCfg`. + ValueError: If the scene entity defined in the term configuration does not exist. + AttributeError: If the term function is not callable. + ValueError: If the term function's arguments are not matched by the parameters. + """ + # check if the term is a valid term config + if not isinstance(term_cfg, ManagerTermBaseCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ManagerTermBaseCfg." + f" Received: '{type(term_cfg)}'." + ) + + # get the corresponding function or functional class + if isinstance(term_cfg.func, str): + term_cfg.func = string_to_callable(term_cfg.func) + # check if function is callable + if not callable(term_cfg.func): + raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}") + + # check if the term is a class of valid type + if inspect.isclass(term_cfg.func): + if not issubclass(term_cfg.func, ManagerTermBase): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ManagerTermBase." + f" Received: '{type(term_cfg.func)}'." + ) + func_static = term_cfg.func.__call__ + min_argc += 1 # forward by 1 to account for 'self' argument + else: + func_static = term_cfg.func + # check if function is callable + if not callable(func_static): + raise AttributeError(f"The term '{term_name}' is not callable. 
Received: {term_cfg.func}") + + # check statically if the term's arguments are matched by params + term_params = list(term_cfg.params.keys()) + args = inspect.signature(func_static).parameters + args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty] + args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty] + args = args_without_defaults + args_with_defaults + # ignore first two arguments for env and env_ids + # Think: Check for cases when kwargs are set inside the function? + if len(args) > min_argc: + if set(args[min_argc:]) != set(term_params + args_with_defaults): + raise ValueError( + f"The term '{term_name}' expects mandatory parameters: {args_without_defaults[min_argc:]}" + f" and optional parameters: {args_with_defaults}, but received: {term_params}." + ) + + # process attributes at runtime + # these properties are only resolvable once the simulation starts playing + if self._env.sim.is_playing(): + self._process_term_cfg_at_play(term_name, term_cfg) + + def _process_term_cfg_at_play(self, term_name: str, term_cfg: ManagerTermBaseCfg): + """Process the term configuration at runtime. + + This function is called when the simulation starts playing. It is used to process the term + configuration at runtime. This includes: + + * Resolving the scene entity configuration for the term. + * Initializing the term if it is a class. + + Since the above steps rely on PhysX to parse over the simulation scene, they are deferred + until the simulation starts playing. + + Args: + term_name: The name of the term. + term_cfg: The term configuration. + """ + for key, value in term_cfg.params.items(): + if isinstance(value, SceneEntityCfg): + # load the entity + try: + value.resolve(self._env.scene) + except ValueError as e: + raise ValueError(f"Error while parsing '{term_name}:{key}'. {e}") + # log the entity for checking later + msg = f"[{term_cfg.__class__.__name__}:{term_name}] Found entity '{value.name}'." 
+ if value.joint_ids is not None: + msg += f"\n\tJoint names: {value.joint_names} [{value.joint_ids}]" + if value.body_ids is not None: + msg += f"\n\tBody names: {value.body_names} [{value.body_ids}]" + # print the information + logger.info(msg) + # store the entity + term_cfg.params[key] = value + + # initialize the term if it is a class + if inspect.isclass(term_cfg.func): + logger.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.") + term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env) From 6cd4b3215e7de6b7f2db5bc5b19cb4a3874c37e6 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Fri, 30 Jan 2026 02:14:52 -0800 Subject: [PATCH 8/8] Added ScopedCapture for reward and action manager --- .../envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md | 11 ++ .../envs/manager_based_rl_env_warp.py | 62 +++++++++- .../envs/mdp/actions/joint_actions.py | 10 +- .../assets/articulation/articulation.py | 113 +++++++----------- .../assets/articulation/articulation_data.py | 99 +++++++++++++++ 5 files changed, 218 insertions(+), 77 deletions(-) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md b/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md index 4023902c828..275efe26f84 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/MANAGER_BASED_WARP_MIGRATION_PLAN.md @@ -83,6 +83,17 @@ Notes (Torch → Warp porting): `(action_dim, 2)` for `clip`) and index by `j` inside kernels. This avoids redundant per-env storage and extra broadcast kernels, while preserving behavior. +Notes (Warp CUDA graph capture in manager-based env): +- Partition the env step into **small stage functions** that only touch persistent CUDA buffers, then capture/replay them with Warp: + - `step_warp_action_process(...)`: `ActionManager.process_action` (env-step) + - `step_warp_action_apply(...)`: `ActionManager.apply_action` + `scene.write_data_to_sim` (sim-step) + - `step_warp_reward_compute(dt)`: `RewardManager.compute(dt)` (env-step) +- Use a helper like `capture_or_launch(fn, *args, **kwargs)` keyed by `fn.__name__` to standardize: + “if first time: `wp.ScopedCapture()`; else: `wp.capture_launch(graph)`”. +- Any captured stage that reads inputs must read from **stable pointers**: + e.g. keep a persistent `wp.array` action input buffer and copy incoming actions into it each step. +- If the launch topology changes (term list / shapes / enabling debug-vis, etc.), invalidate cached graphs and recapture. + ### Phase 3 — Dependency surfacing and hybrid handling What: diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 82db14efa77..7ba7f38d6af 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -92,6 +92,14 @@ def __init__(self, cfg: ManagerBasedRLEnvCfg, render_mode: str | None = None, ** # store the render mode self.render_mode = render_mode + # -- warp graph capture + # These graphs are cached by function name (see `_wp_capture_or_launch`). + self._wp_graphs: dict[str, Any] = {} + # Persistent action input buffer to keep pointer stable for captured graphs. 
+ self._action_in_wp: wp.array = wp.zeros( + (self.num_envs, self.action_manager.total_action_dim), dtype=wp.float32, device=self.device + ) + # initialize data and constants # -- set the framerate of the gym video recorder wrapper so that the playback speed of the produced video matches the simulation self.metadata["render_fps"] = 1 / self.step_dt @@ -160,6 +168,45 @@ def setup_manager_visualizers(self): Operations - MDP """ + def invalidate_wp_graphs(self) -> None: + """Invalidate all cached Warp graphs. + + Call this if the captured launch topology changes (e.g. different term list, shapes, etc.). + """ + self._wp_graphs.clear() + + def _wp_capture_or_launch(self, fn, *args, **kwargs): + """Capture a Warp CUDA graph for ``fn`` on first call, then replay it. + + Notes: + - The cache key is ``fn.__name__`` (intended for bound methods). + - The function body is *not* executed on replay; only the captured CUDA graph is launched. + Therefore, only capture functions that operate purely via persistent CUDA buffers. + """ + print("[INFO] Running captured graph for function: ", fn.__name__) + fn_name = fn.__name__ + graph = self._wp_graphs.get(fn_name) + if graph is None: + with wp.ScopedCapture() as capture: + fn(*args, **kwargs) + self._wp_graphs[fn_name] = capture.graph + else: + wp.capture_launch(graph) + + def step_warp_action_process(self) -> None: + """Captured stage: process actions (env-step frequency).""" + assert self._action_in_wp is not None + self.action_manager.process_action(self._action_in_wp) + + def step_warp_action_apply(self) -> None: + """Captured stage: apply actions + write to sim (sim-step frequency).""" + self.action_manager.apply_action() + self.scene.write_data_to_sim() + + def step_warp_reward_compute(self, dt: float) -> None: + """Captured stage: compute rewards (env-step frequency).""" + self.reward_buf = self.reward_manager.compute(dt=dt) + @Timer(name="env_step", msg="Step took:", enable=True, format="us") def step(self, action: torch.Tensor) -> VecEnvStepReturn: """Execute one time-step of the environment's dynamics and reset terminated environments. @@ -181,14 +228,18 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: A tuple containing the observations, rewards, resets (terminated and truncated) and extras. 
""" # process actions - action_wp = wp.from_torch(action) + # NOTE: keep a persistent action input buffer for graph pointer stability + + action_device = action.to(self.device) + self._action_in_wp = wp.from_torch(action_device, dtype=wp.float32) + with Timer( name="action_manager.process_action", msg="ActionManager.process_action took:", enable=True, format="us", ): - self.action_manager.process_action(action_wp) + self._wp_capture_or_launch(self.step_warp_action_process) self.recorder_manager.record_pre_step() @@ -206,9 +257,8 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: enable=True, format="us", ): - self.action_manager.apply_action() - # set actions into simulator - self.scene.write_data_to_sim() + self._wp_capture_or_launch(self.step_warp_action_apply) + # simulate with Timer(name="simulate", msg="Newton simulation step took:", enable=True, format="us"): self.sim.step(render=False) @@ -244,7 +294,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: enable=True, format="us", ): - self.reward_buf = self.reward_manager.compute(dt=self.step_dt) + self._wp_capture_or_launch(self.step_warp_reward_compute, float(self.step_dt)) if len(self.recorder_manager.active_terms) > 0: # update observations for recording if needed diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py index e4e8a039eff..19823acf69c 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py @@ -102,6 +102,8 @@ class JointAction(ActionTerm): """The offset applied to the input action.""" _clip: wp.array """The clip applied to the input action.""" + _joint_mask: wp.array + """A persistent joint mask for capturable action application.""" def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> None: # initialize the action term @@ -122,6 +124,11 @@ def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> Non if self._num_joints == self._asset.num_joints and not self.cfg.preserve_order: self._joint_ids = slice(None) + # Pre-compute a per-term (non-shared) joint mask. + # NOTE: ArticulationData uses a shared scratch mask for ids->mask resolution, so we must clone the + # resolved mask to keep it stable when multiple action terms control different joint subsets. 
+ self._joint_mask = wp.clone(self._asset.data.resolve_joint_mask(joint_ids=self._joint_ids)) + # create tensors for raw and processed actions (Warp) self._raw_actions = wp.zeros((self.num_envs, self.action_dim), dtype=wp.float32, device=self.device) self._processed_actions = wp.zeros_like(self.raw_actions) @@ -167,6 +174,7 @@ def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> Non clip_low_vec = wp.array(clip_low, dtype=wp.float32, device=self.device) clip_high_vec = wp.array(clip_high, dtype=wp.float32, device=self.device) self._clip = wp.zeros((self.action_dim, 2), dtype=wp.float32, device=self.device) + # TODO(jichuanh): use np.stack([a, b], axis=0) wp.launch( kernel=_set_clip_1d_to_2d, dim=self.action_dim, @@ -269,4 +277,4 @@ def __init__(self, cfg: actions_cfg.JointEffortActionCfg, env: ManagerBasedEnv): def apply_actions(self): # set joint effort targets - self._asset.set_joint_effort_target(self.processed_actions, joint_ids=self._joint_ids) + self._asset.set_joint_effort_target(self.processed_actions, joint_mask=self._joint_mask) diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py index fff003b43b5..0a1dce972bd 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py @@ -49,7 +49,6 @@ from isaaclab.utils.warp.utils import ( make_complete_data_from_torch_dual_index, make_complete_data_from_torch_single_index, - make_masks_from_torch_ids, ) if TYPE_CHECKING: @@ -380,10 +379,7 @@ def write_root_state_to_sim( root_state = make_complete_data_from_torch_single_index( root_state, self.num_instances, ids=env_ids, dtype=vec13f, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # split the state into pose and velocity pose, velocity = self._split_state(root_state) # write the pose and velocity to the simulation @@ -412,9 +408,7 @@ def write_root_com_state_to_sim( root_state = make_complete_data_from_torch_single_index( root_state, self.num_instances, ids=env_ids, dtype=vec13f, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # split the state into pose and velocity pose, velocity = self._split_state(root_state) # write the pose and velocity to the simulation @@ -443,9 +437,7 @@ def write_root_link_state_to_sim( root_state = make_complete_data_from_torch_single_index( root_state, self.num_instances, ids=env_ids, dtype=vec13f, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # split the state into pose and velocity pose, velocity = self._split_state(root_state) # write the pose and velocity to the simulation @@ -490,10 +482,7 @@ def write_root_link_pose_to_sim( pose = make_complete_data_from_torch_single_index( pose, self.num_instances, ids=env_ids, dtype=wp.transformf, device=self.device ) - env_mask = 
make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # set into simulation self._update_array_with_array_masked(pose, self._data.root_link_pose_w, env_mask, self.num_instances) # invalidate the root com pose @@ -520,10 +509,7 @@ def write_root_com_pose_to_sim( root_pose = make_complete_data_from_torch_single_index( root_pose, self.num_instances, ids=env_ids, dtype=wp.transformf, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # Write to Newton using warp self._update_array_with_array_masked(root_pose, self._data._root_com_pose_w.data, env_mask, self.num_instances) # set link frame poses @@ -580,10 +566,7 @@ def write_root_com_velocity_to_sim( root_velocity = make_complete_data_from_torch_single_index( root_velocity, self.num_instances, ids=env_ids, dtype=wp.spatial_vectorf, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # set into simulation self._update_array_with_array_masked(root_velocity, self._data.root_com_vel_w, env_mask, self.num_instances) # invalidate the derived velocities @@ -609,10 +592,7 @@ def write_root_link_velocity_to_sim( root_velocity = make_complete_data_from_torch_single_index( root_velocity, self.num_instances, ids=env_ids, dtype=wp.spatial_vectorf, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) # update the root link velocity self._update_array_with_array_masked( root_velocity, self._data._root_link_vel_w.data, env_mask, self.num_instances @@ -675,8 +655,8 @@ def write_joint_state_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation self._update_batched_array_with_batched_array_masked( @@ -714,8 +694,8 @@ def write_joint_position_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. 
# set into simulation self._update_batched_array_with_batched_array_masked( @@ -750,8 +730,8 @@ def write_joint_velocity_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation self._update_batched_array_with_batched_array_masked( @@ -790,8 +770,8 @@ def write_joint_stiffness_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation if isinstance(stiffness, float): @@ -833,8 +813,8 @@ def write_joint_damping_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation if isinstance(damping, float): @@ -888,8 +868,8 @@ def write_joint_position_limit_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation self._write_joint_position_limit_to_sim(lower_limits, upper_limits, joint_mask, env_mask) @@ -933,8 +913,8 @@ def write_joint_velocity_limit_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation if isinstance(limits, float): @@ -979,8 +959,8 @@ def write_joint_effort_limit_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. 
# set into simulation if isinstance(limits, float): @@ -1025,8 +1005,8 @@ def write_joint_armature_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation if isinstance(armature, float): @@ -1077,8 +1057,8 @@ def write_joint_friction_coefficient_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation if isinstance(joint_friction_coeff, float): @@ -1119,8 +1099,8 @@ def write_joint_dynamic_friction_coefficient_to_sim( dtype=wp.float32, device=self.device, ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # None masks are handled within the kernel. # set into simulation if isinstance(joint_dynamic_friction_coeff, float): @@ -1219,8 +1199,8 @@ def set_masses( masses = make_complete_data_from_torch_dual_index( masses, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.float32, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + body_mask = self._data.resolve_body_mask(body_ids=body_ids, body_mask=body_mask) # None masks are handled within the kernel. self._update_batched_array_with_batched_array_masked( masses, self._data.body_mass, env_mask, body_mask, (self.num_instances, self.num_bodies) @@ -1248,8 +1228,8 @@ def set_coms( coms = make_complete_data_from_torch_dual_index( coms, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.vec3f, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + body_mask = self._data.resolve_body_mask(body_ids=body_ids, body_mask=body_mask) # None masks are handled within the kernel. 
self._update_batched_array_with_batched_array_masked( coms, self._data.body_com_pos_b, env_mask, body_mask, (self.num_instances, self.num_bodies) @@ -1277,8 +1257,8 @@ def set_inertias( inertias = make_complete_data_from_torch_dual_index( inertias, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.mat33f, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + body_mask = self._data.resolve_body_mask(body_ids=body_ids, body_mask=body_mask) # None masks are handled within the kernel. self._update_batched_array_with_batched_array_masked( inertias, self._data.body_inertia, env_mask, body_mask, (self.num_instances, self.num_bodies) @@ -1330,8 +1310,6 @@ def set_external_force_and_torque( the external wrench is applied in the link frame of the articulations' bodies. """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. - env_mask_ = None - body_mask_ = None if isinstance(forces, torch.Tensor) or isinstance(torques, torch.Tensor): if forces is not None: forces = make_complete_data_from_torch_dual_index( @@ -1341,13 +1319,8 @@ def set_external_force_and_torque( torques = make_complete_data_from_torch_dual_index( torques, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.vec3f, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) - # solve for None masks - if env_mask_ is None: - env_mask_ = self._data.ALL_ENV_MASK - if body_mask_ is None: - body_mask_ = self._data.ALL_BODY_MASK + env_mask_ = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + body_mask_ = self._data.resolve_body_mask(body_ids=body_ids, body_mask=body_mask) # set into simulation if (forces is not None) or (torques is not None): self.has_external_wrench = True @@ -1399,8 +1372,8 @@ def set_joint_position_target( target = make_complete_data_from_torch_dual_index( target, self.num_instances, self.num_joints, env_ids, joint_ids, dtype=wp.float32 ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # set into the actuator target buffer wp.launch( update_array2D_with_array2D_masked, @@ -1438,8 +1411,8 @@ def set_joint_velocity_target( target = make_complete_data_from_torch_dual_index( target, self.num_instances, self.num_joints, env_ids, joint_ids, dtype=wp.float32, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # set into the actuator target buffer self._update_batched_array_with_batched_array_masked( target, self._data.actuator_velocity_target, env_mask, joint_mask, (self.num_instances, 
self.num_joints) @@ -1470,8 +1443,8 @@ def set_joint_effort_target( target = make_complete_data_from_torch_dual_index( target, self.num_instances, self.num_joints, env_ids, joint_ids, dtype=wp.float32, device=self.device ) - env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) - joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + env_mask = self._data.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + joint_mask = self._data.resolve_joint_mask(joint_ids=joint_ids, joint_mask=joint_mask) # set into the actuator effort target buffer self._update_batched_array_with_batched_array_masked( target, self._data.actuator_effort_target, env_mask, joint_mask, (self.num_instances, self.num_joints) diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py index b3dae156c1d..f91c822c3f2 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py @@ -7,6 +7,7 @@ import logging import torch import weakref +from collections.abc import Sequence import warp as wp from isaaclab_newton.kernels import ( @@ -17,6 +18,7 @@ compute_heading, derive_body_acceleration_from_velocity_batched, derive_joint_acceleration_from_velocity, + generate_mask_from_ids, generate_pose_from_position_with_unit_quaternion_batched, make_joint_pos_limits_from_lower_and_upper_limits, project_com_velocity_to_link_frame_batch, @@ -123,6 +125,103 @@ def is_primed(self, value: bool): raise RuntimeError("Cannot set is_primed after instantiation.") self._is_primed = value + ## + # Mask resolvers (ids -> wp.bool mask). + ## + + def _resolve_1d_mask( + self, + *, + ids: Sequence[int] | slice | wp.array | torch.Tensor | None, + mask: wp.array | torch.Tensor | None, + all_mask: wp.array, + scratch_mask: wp.array, + ) -> wp.array: + """Resolve ids/mask into a warp boolean mask. + + Notes: + - Returns ``all_mask`` when both ids and mask are None (or ids is slice(None)). + - If ids are provided and mask is None, this populates ``scratch_mask`` in-place using Warp kernels. + - Torch inputs are supported for compatibility, but are generally not CUDA-graph-capture friendly. + """ + # Fast path: explicit mask provided. + if mask is not None: + if isinstance(mask, torch.Tensor): + # Ensure boolean + correct device, then wrap for Warp. + if mask.dtype != torch.bool: + mask = mask.to(dtype=torch.bool) + if str(mask.device) != self.device: + mask = mask.to(self.device) + return wp.from_torch(mask, dtype=wp.bool) + return mask + + # Fast path: ids == all / not specified. + if ids is None: + return all_mask + if isinstance(ids, slice) and ids == slice(None): + return all_mask + + # Normalize ids into a 1D wp.int32 array. + if isinstance(ids, slice): + # Convert to explicit indices (supports partial slices). + # We infer the valid range from the scratch mask length. 
+ start, stop, step = ids.indices(scratch_mask.shape[0]) + ids = list(range(start, stop, step)) + + if isinstance(ids, wp.array): + ids_wp = ids + elif isinstance(ids, torch.Tensor): + if ids.dtype != torch.int32: + ids = ids.to(dtype=torch.int32) + if str(ids.device) != self.device: + ids = ids.to(self.device) + ids_wp = wp.from_torch(ids, dtype=wp.int32) + else: + ids_list = list(ids) + ids_wp = wp.array(ids_list, dtype=wp.int32, device=self.device) if len(ids_list) > 0 else None + + # Populate scratch mask. + scratch_mask.fill_(False) + if ids_wp is not None: + wp.launch( + kernel=generate_mask_from_ids, + dim=ids_wp.shape[0], + inputs=[scratch_mask, ids_wp], + device=self.device, + ) + return scratch_mask + + def resolve_env_mask( + self, + *, + env_ids: Sequence[int] | slice | wp.array | torch.Tensor | None = None, + env_mask: wp.array | torch.Tensor | None = None, + ) -> wp.array: + """Resolve environment ids/mask into a warp boolean mask of shape (num_instances,).""" + return self._resolve_1d_mask(ids=env_ids, mask=env_mask, all_mask=self.ALL_ENV_MASK, scratch_mask=self.ENV_MASK) + + def resolve_body_mask( + self, + *, + body_ids: Sequence[int] | slice | wp.array | torch.Tensor | None = None, + body_mask: wp.array | torch.Tensor | None = None, + ) -> wp.array: + """Resolve body ids/mask into a warp boolean mask of shape (num_bodies,).""" + return self._resolve_1d_mask( + ids=body_ids, mask=body_mask, all_mask=self.ALL_BODY_MASK, scratch_mask=self.BODY_MASK + ) + + def resolve_joint_mask( + self, + *, + joint_ids: Sequence[int] | slice | wp.array | torch.Tensor | None = None, + joint_mask: wp.array | torch.Tensor | None = None, + ) -> wp.array: + """Resolve joint ids/mask into a warp boolean mask of shape (num_joints,).""" + return self._resolve_1d_mask( + ids=joint_ids, mask=joint_mask, all_mask=self.ALL_JOINT_MASK, scratch_mask=self.JOINT_MASK + ) + ## # Names. ##
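[Editor's note, not part of the patch: a self-contained, hedged sketch of the ids -> boolean-mask resolution that `_resolve_1d_mask` performs. The stand-in kernel `_mask_from_ids` mirrors the (mask, ids) argument order of the `generate_mask_from_ids` launch above; the buffer names and sizes in `resolve_mask_example` are assumptions for illustration only.]

import warp as wp


@wp.kernel
def _mask_from_ids(mask: wp.array(dtype=wp.bool), ids: wp.array(dtype=wp.int32)):
    tid = wp.tid()          # one thread per selected id
    mask[ids[tid]] = True   # mark the selected entry


def resolve_mask_example():
    device = "cpu"
    num_joints = 6
    # Cached all-True mask, returned when no ids/mask are given (no allocation, no launch).
    all_mask = wp.full(num_joints, True, dtype=wp.bool, device=device)
    # Shared scratch buffer, reused for every ids -> mask resolution.
    scratch = wp.zeros(num_joints, dtype=wp.bool, device=device)

    # Case 1: ids is None (or slice(None)) -> the cached all-True mask is returned.
    mask = all_mask

    # Case 2: explicit ids -> clear the scratch mask and set only the selected entries.
    ids = wp.array([1, 4], dtype=wp.int32, device=device)
    scratch.fill_(False)
    wp.launch(_mask_from_ids, dim=ids.shape[0], inputs=[scratch, ids], device=device)
    mask = scratch
    # NOTE: because the scratch buffer is shared, callers that need a stable per-term mask
    # (e.g. the JointAction term earlier in this patch) must wp.clone() the result.
    # Equivalent call through the data class API (assuming an `articulation` instance):
    #   mask = articulation.data.resolve_joint_mask(joint_ids=[1, 4])
    return mask

[End of illustrative sketch.]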