Merged
22 commits
cf326f8  Update quadx_ball_in_cup_env.py  (NishantChandna1403, Feb 24, 2025)
d9c82ae  Update quadx_pole_balance_env.py  (NishantChandna1403, Feb 24, 2025)
2fa5197  Update quadx_hover_env.py  (NishantChandna1403, Feb 24, 2025)
a7eca1e  Update quadx_pole_waypoints_env.py  (NishantChandna1403, Feb 24, 2025)
17fd3c4  Update quadx_waypoints_env.py  (NishantChandna1403, Feb 24, 2025)
0359cb4  Merge branch 'jjshoots:master' into master  (NishantChandna1403, Feb 28, 2025)
7c62f8b  Update __init__.py  (NishantChandna1403, Feb 28, 2025)
8a3a1be  Update quadx_pole_balance_env.py  (NishantChandna1403, Feb 28, 2025)
f070f13  Update quadx_waypoints_env.py  (NishantChandna1403, Feb 28, 2025)
9fc29e7  Update quadx_pole_waypoints_env.py  (NishantChandna1403, Feb 28, 2025)
587c36c  Update quadx_pole_balance_env.py  (NishantChandna1403, Feb 28, 2025)
b8e7081  Update quadx_hover_env.py  (NishantChandna1403, Feb 28, 2025)
63b2f56  Update quadx_ball_in_cup_env.py  (NishantChandna1403, Feb 28, 2025)
0631720  Update test_gym_envs.py  (NishantChandna1403, Feb 28, 2025)
3093415  Update quadx_ball_in_cup_env.py  (NishantChandna1403, Feb 28, 2025)
7024da6  Update quadx_hover_env.py  (NishantChandna1403, Feb 28, 2025)
008d7e4  Update quadx_pole_balance_env.py  (NishantChandna1403, Feb 28, 2025)
57b5147  Update quadx_pole_waypoints_env.py  (NishantChandna1403, Feb 28, 2025)
9ac57e7  Update quadx_waypoints_env.py  (NishantChandna1403, Feb 28, 2025)
fb0976a  Update quadx_waypoints_env.py  (NishantChandna1403, Feb 28, 2025)
c072a2e  Update quadx_pole_waypoints_env.py  (NishantChandna1403, Feb 28, 2025)
46377fc  chore: apply pre-commit fixes  (NishantChandna1403, Mar 1, 2025)
10 changes: 5 additions & 5 deletions PyFlyt/gym_envs/__init__.py
@@ -6,27 +6,27 @@

# QuadX Envs
register(
id="PyFlyt/QuadX-Hover-v3",
id="PyFlyt/QuadX-Hover-v4",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_hover_env:QuadXHoverEnv",
)
register(
id="PyFlyt/QuadX-Waypoints-v3",
id="PyFlyt/QuadX-Waypoints-v4",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_waypoints_env:QuadXWaypointsEnv",
)
register(
id="PyFlyt/QuadX-Gates-v3",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_gates_env:QuadXGatesEnv",
)
register(
id="PyFlyt/QuadX-Pole-Balance-v3",
id="PyFlyt/QuadX-Pole-Balance-v4",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_pole_balance_env:QuadXPoleBalanceEnv",
)
register(
id="PyFlyt/QuadX-Pole-Waypoints-v3",
id="PyFlyt/QuadX-Pole-Waypoints-v4",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_pole_waypoints_env:QuadXPoleWaypointsEnv",
)
register(
id="PyFlyt/QuadX-Ball-In-Cup-v3",
id="PyFlyt/QuadX-Ball-In-Cup-v4",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_ball_in_cup_env:QuadXBallInCupEnv",
)

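Since this file only bumps the version suffix on the IDs, downstream code just switches to the new names. A minimal usage sketch, not part of the PR, assuming PyFlyt is installed (the sparse_reward keyword is the one exercised by the environment diffs below):

```python
import gymnasium

# importing PyFlyt.gym_envs runs the register() calls shown above
import PyFlyt.gym_envs  # noqa: F401

# the five QuadX tasks renamed here must now be created with their -v4 IDs
env = gymnasium.make("PyFlyt/QuadX-Hover-v4", sparse_reward=False)

obs, info = env.reset(seed=0)
for _ in range(100):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```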
10 changes: 9 additions & 1 deletion PyFlyt/gym_envs/quadx_envs/quadx_ball_in_cup_env.py
@@ -255,7 +255,6 @@ def compute_state(self) -> None:
def compute_term_trunc_reward(self) -> None:
"""Computes the termination, truncation, and reward of the current timestep."""
super().compute_base_term_trunc_reward()

# compute some parameters of the ball
# lin_pos: [3,], height: [1,], abs_dist: [1,]
ball_rel_lin_pos = self.ball_lin_pos - self.env.state(0)[-1]
@@ -264,6 +263,15 @@ def compute_term_trunc_reward(self) -> None:

# bonus reward if we are not sparse
if not self.sparse_reward:
# penalty for high yaw rate, to discourage excessive yawing during training
yaw_rate = abs(self.env.state(0)[0][2])  # z-axis (yaw) component of the angular velocity
yaw_rate_penalty = 0.01 * yaw_rate**2  # quadratic penalty; 0.01 is a tunable weight
self.reward -= yaw_rate_penalty

# reward for staying alive
self.reward += 0.4

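The yaw-rate term added above (and repeated in the environments below) boils down to a small quadratic penalty. A standalone sketch with illustrative names, not code from the PR:

```python
import numpy as np


def yaw_rate_penalty(angular_velocity: np.ndarray, coeff: float = 0.01) -> float:
    """Quadratic penalty on the z-axis (yaw) angular rate, mirroring the term in this diff."""
    yaw_rate = abs(float(angular_velocity[2]))
    return coeff * yaw_rate**2


# a yaw rate of 2 rad/s costs 0.01 * 2**2 = 0.04 reward per step,
# an order of magnitude below the 0.4 per-step alive bonus above
print(yaw_rate_penalty(np.array([0.0, 0.0, 2.0])))  # 0.04
```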
9 changes: 8 additions & 1 deletion PyFlyt/gym_envs/quadx_envs/quadx_hover_env.py
@@ -117,12 +117,19 @@ def compute_state(self) -> None:
def compute_term_trunc_reward(self) -> None:
"""Computes the termination, truncation, and reward of the current timestep."""
super().compute_base_term_trunc_reward()

if not self.sparse_reward:
# distance from 0, 0, 1 hover point
linear_distance = np.linalg.norm(
self.env.state(0)[-1] - np.array([0.0, 0.0, 1.0])
)
# penalty for high yaw rate, to discourage excessive yawing during training
yaw_rate = abs(self.env.state(0)[0][2])  # z-axis (yaw) component of the angular velocity
yaw_rate_penalty = 0.01 * yaw_rate**2  # quadratic penalty; 0.01 is a tunable weight
self.reward -= yaw_rate_penalty

# how far are we from 0 roll pitch
angular_distance = np.linalg.norm(self.env.state(0)[1][:2])
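The 0.01 weight is hard-coded in each environment, so changing it means editing the source. For experimentation, an extra penalty can instead be stacked in a wrapper; a hypothetical sketch (the ExtraYawPenalty name and the 0.05 weight are illustrative, and the Aviary is reached through .unwrapped.env, the same self.env object these diffs read state from):

```python
import gymnasium

import PyFlyt.gym_envs  # noqa: F401


class ExtraYawPenalty(gymnasium.Wrapper):
    """Adds a further, tunable yaw-rate penalty on top of the built-in 0.01 term."""

    def __init__(self, env: gymnasium.Env, coeff: float = 0.05):
        super().__init__(env)
        self.coeff = coeff

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # state(0)[0] is the body angular velocity, as used in the diffs above
        yaw_rate = abs(self.env.unwrapped.env.state(0)[0][2])
        reward -= self.coeff * yaw_rate**2
        return obs, reward, terminated, truncated, info


env = ExtraYawPenalty(gymnasium.make("PyFlyt/QuadX-Hover-v4"), coeff=0.05)
```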
10 changes: 9 additions & 1 deletion PyFlyt/gym_envs/quadx_envs/quadx_pole_balance_env.py
@@ -184,7 +184,15 @@ def compute_term_trunc_reward(self) -> None:

# how far are we from 0 roll pitch
angular_distance = np.linalg.norm(self.env.state(0)[1][:2])

self.reward -= linear_distance + angular_distance
self.reward -= self.pole.leaningness
self.reward += 1.0

# penalty for high yaw rate, to discourage excessive yawing during training
yaw_rate = abs(self.env.state(0)[0][2])  # z-axis (yaw) component of the angular velocity
yaw_rate_penalty = 0.01 * yaw_rate**2  # quadratic penalty; 0.01 is a tunable weight
self.reward -= yaw_rate_penalty
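Taken together, the shaped per-step reward in this block is roughly 1.0 minus the linear distance, the roll/pitch distance, the pole leaningness, and the new yaw term. A quick sanity check with made-up values:

```python
# illustrative numbers, not from a real rollout
linear_distance = 0.5   # linear distance term from the lines above
angular_distance = 0.1  # norm of roll and pitch
leaningness = 0.2       # pole tilt term
yaw_rate = 1.0          # rad/s

reward = 1.0 - (linear_distance + angular_distance) - leaningness - 0.01 * yaw_rate**2
print(reward)  # 0.19 -- the yaw penalty stays small unless the yaw rate is large
```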
9 changes: 8 additions & 1 deletion PyFlyt/gym_envs/quadx_envs/quadx_pole_waypoints_env.py
@@ -224,12 +224,19 @@ def compute_state(self) -> None:
def compute_term_trunc_reward(self) -> None:
"""Computes the termination, truncation, and reward of the current timestep."""
super().compute_base_term_trunc_reward()

# bonus reward if we are not sparse
if not self.sparse_reward:
self.reward += max(15.0 * self.waypoints.progress_to_next_target, 0.0)
self.reward += 0.5 / self.waypoints.distance_to_next_target
self.reward += 0.5 - self.pole.leaningness
# penalty for high yaw rate, to discourage excessive yawing during training
yaw_rate = abs(self.env.state(0)[0][2])  # z-axis (yaw) component of the angular velocity
yaw_rate_penalty = 0.01 * yaw_rate**2  # quadratic penalty; 0.01 is a tunable weight
self.reward -= yaw_rate_penalty

# target reached
if self.waypoints.target_reached:
8 changes: 8 additions & 0 deletions PyFlyt/gym_envs/quadx_envs/quadx_waypoints_env.py
@@ -182,6 +182,14 @@ def compute_term_trunc_reward(self) -> None:
if not self.sparse_reward:
self.reward += max(3.0 * self.waypoints.progress_to_next_target, 0.0)
self.reward += 0.1 / self.waypoints.distance_to_next_target
# penalty for high yaw rate, to discourage excessive yawing during training
yaw_rate = abs(self.env.state(0)[0][2])  # z-axis (yaw) component of the angular velocity
yaw_rate_penalty = 0.01 * yaw_rate**2  # quadratic penalty; 0.01 is a tunable weight
self.reward -= yaw_rate_penalty

# target reached
if self.waypoints.target_reached:
10 changes: 5 additions & 5 deletions tests/test_gym_envs.py
@@ -16,8 +16,8 @@
_WAYPOINT_ENV_CONFIGS = []
for env_name, angle_representation, sparse_reward in itertools.product(
[
"PyFlyt/QuadX-Waypoints-v3",
"PyFlyt/QuadX-Pole-Waypoints-v3",
"PyFlyt/QuadX-Waypoints-v4",
"PyFlyt/QuadX-Pole-Waypoints-v4",
"PyFlyt/Fixedwing-Waypoints-v3",
],
["euler", "quaternion"],
@@ -37,9 +37,9 @@
_NORMAL_ENV_CONFIGS = []
for env_name, angle_representation, sparse_reward in itertools.product(
[
"PyFlyt/QuadX-Hover-v3",
"PyFlyt/QuadX-Pole-Balance-v3",
"PyFlyt/QuadX-Ball-In-Cup-v3",
"PyFlyt/QuadX-Hover-v4",
"PyFlyt/QuadX-Pole-Balance-v4",
"PyFlyt/QuadX-Ball-In-Cup-v4",
"PyFlyt/Rocket-Landing-v4",
],
["euler", "quaternion"],
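The ID lists above feed config tuples that are presumably consumed by parametrized tests later in this file. A minimal sketch of such a smoke test, assuming pytest and gymnasium's env checker (the actual test bodies are outside this diff, and the config list here is illustrative):

```python
import itertools

import gymnasium
import pytest
from gymnasium.utils.env_checker import check_env

# importing PyFlyt.gym_envs registers the environments under test
import PyFlyt.gym_envs  # noqa: F401

# illustrative subset of the product built in this file
_SMOKE_CONFIGS = list(
    itertools.product(
        ["PyFlyt/QuadX-Hover-v4", "PyFlyt/QuadX-Ball-In-Cup-v4"],
        ["euler", "quaternion"],
        [True, False],
    )
)


@pytest.mark.parametrize("env_name, angle_representation, sparse_reward", _SMOKE_CONFIGS)
def test_env_passes_api_check(env_name, angle_representation, sparse_reward):
    env = gymnasium.make(
        env_name,
        angle_representation=angle_representation,
        sparse_reward=sparse_reward,
    )
    check_env(env.unwrapped, skip_render_check=True)
    env.close()
```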