Source code for cambrian.ml.features_extractors
"""This module contains custom feature extractors for use in the models."""
from typing import Callable, Dict, List, Type
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
)
from stable_baselines3.common.type_aliases import TensorDict
# ==================
# Utils
def is_image_space(
observation_space: gym.Space,
check_channels: bool = False,
normalized_image: bool = False,
) -> bool:
"""This is an extension of the sb3 is_image_space to support both regular images
(HxWxC) and images with an additional dimension (NxHxWxC)."""
from stable_baselines3.common.preprocessing import (
is_image_space as sb3_is_image_space,
)
return len(observation_space.shape) == 4 or sb3_is_image_space(
observation_space, normalized_image=normalized_image
)
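
# A quick illustration of the extended predicate (shapes chosen arbitrarily):
#
#   is_image_space(spaces.Box(0, 255, (64, 64, 3), np.uint8))     # HxWxC -> True
#   is_image_space(spaces.Box(0, 255, (4, 64, 64, 3), np.uint8))  # NxHxWxC -> True
#   is_image_space(spaces.Box(-1, 1, (7,), np.float32))           # vector -> False
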
def maybe_transpose_space(observation_space: spaces.Box) -> spaces.Box:
"""This is an extension of the sb3 maybe_transpose_space to support both regular
images (HxWxC) and images with an additional dimension (NxHxWxC). sb3 will call
maybe_transpose_space on the 3D case, but not the 4D."""
    if len(observation_space.shape) == 4:
        # NxHxWxC -> NxCxHxW. Transpose (rather than reshape) the bounds so
        # that per-element low/high values stay attached to the right pixels.
        observation_space = spaces.Box(
            low=observation_space.low.transpose(0, 3, 1, 2),
            high=observation_space.high.transpose(0, 3, 1, 2),
            dtype=observation_space.dtype,
        )
return observation_space
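
# Sketch of the 4D transpose (illustrative shapes):
#
#   space = spaces.Box(0, 255, (4, 24, 24, 3), np.uint8)  # NxHxWxC
#   maybe_transpose_space(space).shape  # -> (4, 3, 24, 24), i.e. NxCxHxW
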
def maybe_transpose_obs(observation: torch.Tensor) -> torch.Tensor:
"""This is an extension of the sb3 maybe_transpose_obs to support both regular
images (HxWxC) and images with an additional dimension (NxHxWxC). sb3 will call
maybe_transpose_obs on the 3D case, but not the 4D.
Note:
In this case, there is a batch dimension, so the observation is 5D.
"""
if len(observation.shape) == 5:
        observation = observation.permute(0, 1, 4, 2, 3)  # [B, N, H, W, C] -> [B, N, C, H, W]
return observation
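
# The batched analogue of the space transpose above (illustrative shapes):
#
#   obs = torch.zeros(8, 4, 24, 24, 3)  # [B, N, H, W, C]
#   maybe_transpose_obs(obs).shape      # -> torch.Size([8, 4, 3, 24, 24])
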
# ==================
# Feature Extractors
class MjCambrianCombinedExtractor(BaseFeaturesExtractor):
"""Overwrite of the default feature extractor of Stable Baselines 3."""
def __init__(
self,
observation_space: spaces.Dict,
*,
normalized_image: bool,
        image_extractor: Callable[[gym.Space], BaseFeaturesExtractor],
share_image_extractor: bool = False,
) -> None:
# We do not know features-dim here before going over all the items, so put
# something there.
super().__init__(observation_space, features_dim=1)
self._image_extractor = None
if share_image_extractor:
# Verify all the image spaces have the same shape
image_space = None
for subspace in observation_space.values():
if is_image_space(subspace, normalized_image=normalized_image):
subspace = maybe_transpose_space(subspace)
if image_space is None:
image_space = subspace
                    assert image_space.shape == subspace.shape, (
                        "All image spaces must have the same shape when "
                        "using a shared image extractor"
                    )
assert image_space is not None, "There must be at least one image space"
self._image_extractor = image_extractor(image_space)
extractors: Dict[str, BaseFeaturesExtractor] = {}
total_concat_size = 0
for key, subspace in observation_space.spaces.items():
if is_image_space(subspace, normalized_image=normalized_image):
subspace = maybe_transpose_space(subspace)
if share_image_extractor:
extractors[key] = self._image_extractor
else:
extractors[key] = image_extractor(subspace)
else:
# The observation key is a vector, flatten it if needed
extractors[key] = FlattenExtractor(subspace)
total_concat_size += extractors[key].features_dim
self.extractors = torch.nn.ModuleDict(extractors)
# Update the features dim manually
self._features_dim = total_concat_size
def forward(self, observations: TensorDict) -> torch.Tensor:
encoded_tensor_list = []
for key, extractor in self.extractors.items():
obs = maybe_transpose_obs(observations[key])
encoded_tensor_list.append(extractor(obs))
return torch.cat(encoded_tensor_list, dim=1)
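
# Minimal construction sketch (hypothetical observation keys; ``functools.partial``
# is one way to bind the extractor kwargs before passing it in):
#
#   from functools import partial
#   obs_space = spaces.Dict(
#       eye=spaces.Box(0, 255, (4, 24, 24, 3), np.uint8),   # NxHxWxC image stack
#       qpos=spaces.Box(-1, 1, (7,), np.float32),           # proprioceptive vector
#   )
#   extractor = MjCambrianCombinedExtractor(
#       obs_space,
#       normalized_image=False,
#       image_extractor=partial(
#           MjCambrianMLPExtractor,
#           features_dim=64,
#           activation=torch.nn.ReLU,
#           architecture=[128, 64],
#       ),
#   )
#   # features_dim is then 64 (image) + 7 (flattened vector) = 71
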
class PermutedFlattenExtractor(FlattenExtractor):
    """Flatten extractor that randomly permutes the flattened features. Note that
    a fresh permutation is drawn on every forward pass."""

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        flattened = super().forward(observations)
        perm = torch.randperm(flattened.size(-1), device=flattened.device)
        return flattened[:, perm]
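
# Illustrative call (hypothetical space): the output shape matches the input's
# flattened size, but the column order is re-shuffled on every forward pass:
#
#   ext = PermutedFlattenExtractor(spaces.Box(-1, 1, (7,), np.float32))
#   ext(torch.rand(8, 7)).shape  # -> torch.Size([8, 7])
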
class MjCambrianImageFeaturesExtractor(BaseFeaturesExtractor):
"""This is a feature extractor for images. Will implement an image queue for
temporal features. Should be inherited by other classes."""
def __init__(
self,
observation_space: gym.Space,
features_dim: int,
        activation: Type[torch.nn.Module],
):
super().__init__(observation_space, features_dim)
self._queue_size = 1
if len(observation_space.shape) == 4:
self._queue_size = observation_space.shape[0]
        # The space is assumed to be channels-first here (NxCxHxW), as produced
        # by maybe_transpose_space.
        n_channels, height, width = observation_space.shape[-3:]
self._num_pixels = n_channels * height * width
self.temporal_linear = torch.nn.Sequential(
torch.nn.Linear(features_dim * self._queue_size, features_dim),
activation(),
)
def forward(self, observations: torch.Tensor) -> torch.Tensor:
return self.temporal_linear(observations)
class MjCambrianMLPExtractor(MjCambrianImageFeaturesExtractor):
"""MLP feature extractor for small images. Essentially NatureCNN but with MLPs."""
def __init__(
self,
observation_space: gym.Space,
features_dim: int,
        activation: Type[torch.nn.Module],
architecture: List[int],
) -> None:
super().__init__(observation_space, features_dim, activation)
layers = []
layers.append(torch.nn.Flatten())
layers.append(torch.nn.Linear(self._num_pixels, architecture[0]))
layers.append(activation())
for i in range(1, len(architecture)):
layers.append(torch.nn.Linear(architecture[i - 1], architecture[i]))
layers.append(activation())
layers.append(torch.nn.Linear(architecture[-1], features_dim))
layers.append(activation())
self.mlp = torch.nn.Sequential(*layers)
def forward(self, observations: torch.Tensor) -> torch.Tensor:
B = observations.shape[0]
        observations = observations.reshape(-1, self._num_pixels)  # [B * N, C * H * W]
        encodings = self.mlp(observations)  # [B * N, features_dim]
        encodings = encodings.reshape(B, -1)  # [B, N * features_dim]
return super().forward(encodings)
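
# Shape walk-through for a (4, 3, 24, 24) stacked-image space with
# features_dim=64 (illustrative numbers):
#
#   input [B, 4, 3, 24, 24] -> reshape [B * 4, 1728] -> mlp [B * 4, 64]
#   -> reshape [B, 256] -> temporal_linear -> [B, 64]
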
class MjCambrianNatureCNNExtractor(MjCambrianImageFeaturesExtractor):
"""Nature CNN feature extractor for images. This is the default feature extractor
for stable baseline3 images. The main differences between this and the original
is that this supports temporal features (i.e. image stacks) and dynamically
calculates the kernel sizes and strides. In sb3, the fixed kernel sizes and strides
restricted the image size to be > 36x36, which is a bd assumption here."""
    def __init__(
        self,
        observation_space: gym.Space,
        features_dim: int,
        activation: Type[torch.nn.Module],
    ):
super().__init__(observation_space, features_dim, activation)
        # The space is assumed to be channels-first here (NxCxHxW), as produced
        # by maybe_transpose_space.
        n_channels, height, width = observation_space.shape[-3:]
# Dynamically calculate kernel sizes and strides
k_sizes, strides = self.calculate_dynamic_params(width, height)
# Create CNN layers
self.cnn = torch.nn.Sequential(
torch.nn.Conv2d(n_channels, 32, kernel_size=k_sizes[0], stride=strides[0]),
activation(),
torch.nn.Conv2d(32, 64, kernel_size=k_sizes[1], stride=strides[1]),
activation(),
torch.nn.Conv2d(64, 64, kernel_size=k_sizes[2], stride=strides[2]),
activation(),
torch.nn.Flatten(),
)
# Compute shape by doing one forward pass
        with torch.no_grad():
            # Cast to float: image samples are typically uint8, which Conv2d
            # cannot consume directly.
            sample = torch.as_tensor(observation_space.sample()).float()
            n_flatten = self.cnn(sample)
            # The temporal dim acts as the batch dim in this dry run, so the
            # total flattened size is queue_size * per-image feature count.
            n_flatten = n_flatten.shape[0] * n_flatten.shape[1]
self.linear = torch.nn.Sequential(
torch.nn.Linear(n_flatten, self._queue_size * features_dim), activation()
)
def calculate_dynamic_params(self, width, height):
        # Max kernel sizes and strides, based on sb3's NatureCNN defaults (for
        # inputs larger than 36x36 these upper bounds are used as-is).
max_kernel_sizes = [8, 4, 2]
max_strides = [4, 2, 1]
# Adjust kernel sizes and strides based on input dimensions
kernel_sizes = [min(k, height, width) for k in max_kernel_sizes]
strides = [
min(s, height // k, width // k) for s, k in zip(max_strides, kernel_sizes)
]
return kernel_sizes, strides
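
    # Worked example (illustrative): for a 10x10 input,
    #   kernel_sizes = [min(8, 10, 10), min(4, 10, 10), min(2, 10, 10)] = [8, 4, 2]
    #   strides      = [min(4, 10 // 8, 10 // 8),
    #                   min(2, 10 // 4, 10 // 4),
    #                   min(1, 10 // 2, 10 // 2)] = [1, 2, 1]
    # so small images fall back to gentler strides than the fixed sb3 defaults,
    # which assume inputs larger than 36x36.
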
    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        B = observations.shape[0]
        # Fold the temporal dim into the batch dim: [B, N, C, H, W] -> [B * N, C, H, W]
        observations = observations.reshape(-1, *observations.shape[2:])
        observations = self.cnn(observations)  # [B * N, n_features]
        observations = observations.reshape(B, -1)  # [B, N * n_features]
        return super().forward(self.linear(observations))
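
# End-to-end sketch of the CNN extractor (illustrative shapes; the space is
# assumed to already be channels-first, as produced by maybe_transpose_space):
#
#   space = spaces.Box(0, 255, (4, 3, 24, 24), np.uint8)  # NxCxHxW
#   cnn = MjCambrianNatureCNNExtractor(space, features_dim=64,
#                                      activation=torch.nn.ReLU)
#   obs = torch.rand(8, 4, 3, 24, 24)  # [B, N, C, H, W]
#   cnn(obs).shape  # -> torch.Size([8, 64])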