Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ def clear_tm_instances():
_, uncompiled_shared_network = make_untrained_iqn_network(jit=config_copy.use_jit, is_inference=False)
uncompiled_shared_network.share_memory()

# Initialize the random number generator with a fixed seed
seed = 275328254363729247691611008422666101254
# Create the root RNG that is passed around; rng.spawn() derives new, independent child generators from it.
rng = np.random.default_rng(seed)

# Start learner process
learner_process = mp.Process(
target=learner_process_fn,
Expand All @@ -129,6 +134,7 @@ def clear_tm_instances():
base_dir,
save_dir,
tensorboard_base_dir,
rng.spawn(1)[0],
),
)
learner_process.start()
Expand All @@ -148,6 +154,7 @@ def clear_tm_instances():
base_dir,
save_dir,
config_copy.base_tmi_port + process_number,
rng.spawn(1)[0],
),
)
for rollout_queue, process_number in zip(rollout_queues, range(config_copy.gpu_collectors_count))
Expand Down
9 changes: 6 additions & 3 deletions trackmania_rl/agents/iqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- The Trainer class, which implements the IQN training logic in method train_on_batch.
- The Inferer class, which implements utilities for forward propagation with and without exploration.
"""

import copy
import math
import random
Expand Down Expand Up @@ -350,15 +351,17 @@ class Inferer:
"epsilon_boltzmann",
"tau_epsilon_boltzmann",
"is_explo",
"_rng",
)

def __init__(self, inference_network, iqn_k, tau_epsilon_boltzmann):
def __init__(self, inference_network, iqn_k, tau_epsilon_boltzmann, rng: np.random.Generator):
self.inference_network = inference_network
self.iqn_k = iqn_k
self.epsilon = None
self.epsilon_boltzmann = None
self.tau_epsilon_boltzmann = tau_epsilon_boltzmann
self.is_explo = None
self._rng = rng

def infer_network(self, img_inputs_uint8: npt.NDArray, float_inputs: npt.NDArray, tau=None) -> npt.NDArray:
"""
Expand Down Expand Up @@ -415,9 +418,9 @@ def get_exploration_action(self, img_inputs_uint8: npt.NDArray, float_inputs: np

if self.is_explo and r < self.epsilon:
# Choose a random action
get_argmax_on = np.random.randn(*q_values.shape)
get_argmax_on = self._rng.standard_normal(*q_values.shape)
elif self.is_explo and r < self.epsilon + self.epsilon_boltzmann:
get_argmax_on = q_values + self.tau_epsilon_boltzmann * np.random.randn(*q_values.shape)
get_argmax_on = q_values + self.tau_epsilon_boltzmann * self._rng.standard_normal(*q_values.shape)
else:
get_argmax_on = q_values

Expand Down
Loading