14 changes: 8 additions & 6 deletions README.md
@@ -30,17 +30,19 @@ Follow these steps to set up locally before you run the `notebooks/SAC_Demo.ipynb`

2. Ensure you have `protoc` and `ffmpeg` installed, as well as `python >=3.10.12 and <3.12`. You can install these by running `sudo apt install -y protobuf-compiler` and `sudo apt install -y ffmpeg`

- 3. Create a virtual environment by running `python -m venv .venv`. Then, install poetry with `pip install poetry`
+ 3. Create a virtual environment by running `python -m venv .venv` and then activate it by running `source .venv/bin/activate`.

- 4. Install the dependencies by running `poetry install`
+ 4. Then, install poetry with `pip install poetry`.

- 5. Build the `.proto` files at `smart_control/proto`into python files by running `cd smart_control/proto && protoc --python_out=. smart_control_building.proto smart_control_normalization.proto smart_control_reward.proto && cd ../..`
+ 5. Install the dependencies by running `poetry install`.

- 6. Modify the value of `VIDEO_PATH_ROOT` at `smart_control/simulator/constants.py`. This is the path where simulation videos will be stored
+ 6. Build the `.proto` files at `smart_control/proto` into Python files by running `cd smart_control/proto && protoc --python_out=. smart_control_building.proto smart_control_normalization.proto smart_control_reward.proto && cd ../..`

- 7. Now in the `notebooks/SAC_Demo.ipynb` notebook, modify the values of `data_path`, `metrics_path`, `output_data_path` and `root_dir`. In particular, `data_path` should point to the `sim_config.gin` file at `smart_control/configs/sim_config.gin`
+ 7. Modify the value of `VIDEO_PATH_ROOT` at `smart_control/simulator/constants.py`. This is the path where simulation videos will be stored.

- 8. Now you are ready to run the `notebooks/SAC_Demo.ipynb` notebook
+ 8. In the `notebooks/SAC_Demo.ipynb` notebook, modify the values of `data_path`, `metrics_path`, `output_data_path` and `root_dir`. In particular, `data_path` should point to the `sim_config.gin` file at `smart_control/configs/sim_config.gin` (see the sketch after this list).

+ 9. Now you are ready to run the `notebooks/SAC_Demo.ipynb` notebook.
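For the new step 8, a minimal sketch of the notebook assignments; only `data_path` is dictated by the README, and the other three paths are placeholders to adapt to your machine:

```python
# Hypothetical values for the variables step 8 asks you to modify.
# Only data_path is specified by the README; the rest are assumptions.
data_path = 'smart_control/configs/sim_config.gin'
metrics_path = '/tmp/smart_control/metrics'       # any writable directory
output_data_path = '/tmp/smart_control/output'    # any writable directory
root_dir = '/tmp/smart_control/root'              # any writable directory
```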

## Real World Data

381 changes: 370 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -26,6 +26,11 @@ scikit-learn = "^1.5.1"
ipykernel = "^6.29.5"
typing-extensions = "^4.12.2"
ipython = "^8.27.0"
+ mctspy = "^0.1.1"
+ tqdm = "^4.67.0"
+ wrapt = "1.14.1"
+ pyyaml = "^6.0.2"
+ pytest = "^8.3.5"


[build-system]
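Note that `wrapt` is pinned to exactly `1.14.1` while the other additions use caret ranges; exact pins like this are usually chosen for compatibility with TensorFlow's bundled `wrapt`, though the diff does not state the reason. A quick check after `poetry install`, as a sketch:

```python
# Verify the exact wrapt pin resolved as expected.
# (The TensorFlow-compatibility rationale is an assumption, not stated in the PR.)
import wrapt

assert wrapt.__version__ == '1.14.1', f'unexpected wrapt version: {wrapt.__version__}'
```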

Several additional changed files: large diffs are not rendered by default.

625 changes: 416 additions & 209 deletions smart_control/notebooks/SAC_Demo.ipynb

Large diffs are not rendered by default.

Empty file.
8 changes: 8 additions & 0 deletions smart_control/refactor/agents/__init__.py
@@ -0,0 +1,8 @@
from smart_control.refactor.agents.base_agent import BaseAgent, TFAgentWrapper
from smart_control.refactor.agents.sac_agent import create_sac_agent

__all__ = [
    'BaseAgent',
    'TFAgentWrapper',
    'create_sac_agent',
]
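A one-line sketch of the import surface this `__init__.py` provides to downstream code (`create_sac_agent` is defined in `sac_agent.py`, whose diff is not rendered here):

```python
# Hypothetical downstream import: callers can skip the deep module paths.
from smart_control.refactor.agents import BaseAgent, TFAgentWrapper, create_sac_agent
```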
137 changes: 137 additions & 0 deletions smart_control/refactor/agents/base_agent.py
@@ -0,0 +1,137 @@
import abc
import logging
from typing import Any, Dict

import tensorflow as tf
from tf_agents.agents import tf_agent
from tf_agents.policies import tf_policy


logger = logging.getLogger(__name__)


class BaseAgent(abc.ABC):
    """Abstract base class for all RL agents.

    This class defines the core interface that all agents must implement.
    """

    @abc.abstractmethod
    def initialize(self) -> None:
        """Initialize the agent.

        This method should be called before using the agent.
        """
        pass

    @abc.abstractmethod
    def train(self, experience) -> Dict[str, Any]:
        """Train the agent on a batch of experience.

        Args:
            experience: A batch of experience data for training.

        Returns:
            A dictionary of loss metrics from training.
        """
        pass

    @property
    @abc.abstractmethod
    def policy(self) -> tf_policy.TFPolicy:
        """Returns the agent's main policy."""
        pass

    @property
    @abc.abstractmethod
    def collect_policy(self) -> tf_policy.TFPolicy:
        """Returns the agent's collection policy."""
        pass

    @property
    @abc.abstractmethod
    def collect_data_spec(self):
        """Returns the agent's data collection specification."""
        pass

    @property
    @abc.abstractmethod
    def train_step_counter(self) -> tf.Variable:
        """Returns the agent's training step counter."""
        pass


class TFAgentWrapper(BaseAgent):
    """Wrapper class for TF-Agents agents to conform to the BaseAgent interface."""

    def __init__(self, tf_agent_instance: tf_agent.TFAgent):
        """Initialize with a TF-Agents agent instance.

        Args:
            tf_agent_instance: A TF-Agents agent instance.
        """
        self._agent = tf_agent_instance

    def initialize(self) -> None:
        """Initialize the agent."""
        self._agent.initialize()

    def train(self, experience) -> Dict[str, Any]:
        """Train the agent on a batch of experience.

        Args:
            experience: A batch of experience data for training.

        Returns:
            A dictionary of loss metrics from training.
        """
        loss_info = self._agent.train(experience)

        result = {'loss': loss_info.loss}

        # Handle the different kinds of extra info that different agents return.
        if hasattr(loss_info, 'extra'):
            logger.info('Extra loss info found in agent training result')
            logger.info('Extra loss info type: %s', type(loss_info.extra))
            extra = loss_info.extra

            # SAC's extra is a LossInfo namedtuple with fields like actor_loss,
            # critic_loss and alpha_loss; collect its public, non-callable
            # attributes into a plain dict.
            extra_dict = {}
            for attr_name in dir(extra):
                # Skip private attributes and methods.
                if not attr_name.startswith('_') and not callable(
                    getattr(extra, attr_name)
                ):
                    attr_value = getattr(extra, attr_name)
                    # Convert TensorFlow tensors to numpy arrays.
                    if hasattr(attr_value, 'numpy'):
                        extra_dict[attr_name] = attr_value.numpy()
                    else:
                        extra_dict[attr_name] = attr_value

            result['extra'] = extra_dict

        return result

    # Expose the underlying agent directly for checkpointing.
    @property
    def agent(self):
        return self._agent

    @property
    def policy(self) -> tf_policy.TFPolicy:
        """Returns the agent's main policy."""
        return self._agent.policy

    @property
    def collect_policy(self) -> tf_policy.TFPolicy:
        """Returns the agent's collection policy."""
        return self._agent.collect_policy

    @property
    def collect_data_spec(self):
        """Returns the agent's data collection specification."""
        return self._agent.collect_data_spec

    @property
    def train_step_counter(self) -> tf.Variable:
        """Returns the agent's training step counter."""
        return self._agent.train_step_counter
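A short usage sketch for `TFAgentWrapper`, assuming a TF-Agents agent and a replay batch already exist; `sac_agent_instance` and `experience` are placeholders, not names from this diff:

```python
# Minimal sketch: wrap an existing TF-Agents agent behind the BaseAgent interface.
# sac_agent_instance (a tf_agent.TFAgent) and experience (a batched trajectory)
# are assumed to be built elsewhere, e.g. by create_sac_agent and a replay buffer.
from smart_control.refactor.agents import TFAgentWrapper

agent = TFAgentWrapper(sac_agent_instance)
agent.initialize()

metrics = agent.train(experience)
print('loss:', metrics['loss'])
for name, value in metrics.get('extra', {}).items():
    print(name, value)  # e.g. actor_loss, critic_loss, alpha_loss for SAC
```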
9 changes: 9 additions & 0 deletions smart_control/refactor/agents/networks/__init__.py
@@ -0,0 +1,9 @@
from smart_control.refactor.agents.networks.sac_networks import (
    create_sequential_actor_network,
    create_sequential_critic_network,
)

__all__ = [
    'create_sequential_actor_network',
    'create_sequential_critic_network',
]
151 changes: 151 additions & 0 deletions smart_control/refactor/agents/networks/sac_networks.py
@@ -0,0 +1,151 @@
"""Network architectures for SAC agent.

This module provides functions to create actor and critic networks for SAC agents.
"""

import functools
from typing import Callable, Optional, Sequence, Tuple

import tensorflow as tf
from tf_agents.networks import nest_map
from tf_agents.networks import sequential
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.keras_layers import inner_reshape
from tf_agents.typing import types


# Utility to create dense layers with consistent initialization and activation
dense = functools.partial(
tf.keras.layers.Dense,
activation=tf.keras.activations.relu,
kernel_initializer='glorot_uniform',
)


def create_fc_network(layer_units: Sequence[int]) -> tf.keras.Model:
"""Creates a fully connected network.

Args:
layer_units: A sequence of layer units.

Returns:
A sequential model of dense layers.
"""
return sequential.Sequential([dense(num_units) for num_units in layer_units])


def create_identity_layer() -> tf.keras.layers.Layer:
"""Creates an identity layer.

Returns:
A Lambda layer that returns its input.
"""
return tf.keras.layers.Lambda(lambda x: x)


def create_sequential_critic_network(
    obs_fc_layer_units: Sequence[int],
    action_fc_layer_units: Sequence[int],
    joint_fc_layer_units: Sequence[int],
) -> sequential.Sequential:
    """Creates a sequential critic network for SAC.

    Args:
        obs_fc_layer_units: Units for observation network layers.
        action_fc_layer_units: Units for action network layers.
        joint_fc_layer_units: Units for joint network layers.

    Returns:
        A sequential critic network.
    """

    # Split the inputs into observations and actions.
    def split_inputs(inputs):
        return {'observation': inputs[0], 'action': inputs[1]}

    # Create an observation network.
    obs_network = (
        create_fc_network(obs_fc_layer_units)
        if obs_fc_layer_units
        else create_identity_layer()
    )

    # Create an action network.
    action_network = (
        create_fc_network(action_fc_layer_units)
        if action_fc_layer_units
        else create_identity_layer()
    )

    # Create a joint network.
    joint_network = (
        create_fc_network(joint_fc_layer_units)
        if joint_fc_layer_units
        else create_identity_layer()
    )

    # Final layer producing a single Q-value.
    value_layer = tf.keras.layers.Dense(1, kernel_initializer='glorot_uniform')

    return sequential.Sequential(
        [
            tf.keras.layers.Lambda(split_inputs),
            nest_map.NestMap(
                {'observation': obs_network, 'action': action_network}
            ),
            nest_map.NestFlatten(),
            tf.keras.layers.Concatenate(),
            joint_network,
            value_layer,
            # Squeeze the trailing unit dimension so outputs are scalar Q-values.
            inner_reshape.InnerReshape(current_shape=[1], new_shape=[]),
        ],
        name='sequential_critic',
    )


class _TanhNormalProjectionNetworkWrapper(
    tanh_normal_projection_network.TanhNormalProjectionNetwork
):
    """Wrapper to pass a predefined `outer_rank` to the underlying projection net."""

    def __init__(self, sample_spec, predefined_outer_rank=1):
        super(_TanhNormalProjectionNetworkWrapper, self).__init__(sample_spec)
        self.predefined_outer_rank = predefined_outer_rank

    def call(self, inputs, network_state=(), **kwargs):
        kwargs['outer_rank'] = self.predefined_outer_rank
        if 'step_type' in kwargs:
            del kwargs['step_type']
        return super(_TanhNormalProjectionNetworkWrapper, self).call(
            inputs, **kwargs
        )


def create_sequential_actor_network(
    actor_fc_layers: Sequence[int],
    action_tensor_spec: types.NestedTensorSpec,
) -> sequential.Sequential:
    """Creates a sequential actor network for SAC.

    Args:
        actor_fc_layers: Units for the actor network's fully connected layers.
        action_tensor_spec: The action tensor spec.

    Returns:
        A sequential actor network.
    """

    # Broadcast the dense-stack output to every leaf of the action nest.
    def tile_as_nest(non_nested_output):
        return tf.nest.map_structure(
            lambda _: non_nested_output, action_tensor_spec
        )

    return sequential.Sequential(
        [dense(num_units) for num_units in actor_fc_layers]
        + [tf.keras.layers.Lambda(tile_as_nest)]
        + [
            nest_map.NestMap(
                tf.nest.map_structure(
                    _TanhNormalProjectionNetworkWrapper, action_tensor_spec
                )
            )
        ]
    )
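To show how the two factories plug together, here is a hedged sketch that builds both networks from an environment's specs and passes them to a stock TF-Agents `SacAgent`; the `env` variable and all layer widths are assumptions, and the PR's own `create_sac_agent` (not rendered above) presumably does something similar:

```python
# Sketch: wire the network factories into tf_agents' SacAgent.
# env is an assumed TFEnvironment; layer widths are illustrative only.
import tensorflow as tf
from tf_agents.agents.sac import sac_agent

action_spec = env.action_spec()
time_step_spec = env.time_step_spec()

actor_net = create_sequential_actor_network(
    actor_fc_layers=(256, 256),
    action_tensor_spec=action_spec,
)
critic_net = create_sequential_critic_network(
    obs_fc_layer_units=(128,),
    action_fc_layer_units=(128,),
    joint_fc_layer_units=(256, 256),
)

agent = sac_agent.SacAgent(
    time_step_spec,
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    alpha_optimizer=tf.keras.optimizers.Adam(3e-4),
    train_step_counter=tf.Variable(0),
)
agent.initialize()
```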