cambrian.ml.features_extractors

This module contains custom feature extractors for use in the models.

Classes

MjCambrianCombinedExtractor

Override of the default feature extractor of Stable Baselines 3.

PermutedFlattenExtractor

Feature extractor that flattens the input.

MjCambrianImageFeaturesExtractor

Feature extractor for images. Implements an image queue for temporal features.

MjCambrianMLPExtractor

MLP feature extractor for small images. Essentially NatureCNN but with MLPs.

MjCambrianNatureCNNExtractor

Nature CNN feature extractor for images. This is the default feature extractor for Stable Baselines 3 images.

Functions

is_image_space(observation_space[, check_channels, ...])

This is an extension of the sb3 is_image_space to support both regular images (HxWxC) and images with an additional dimension (NxHxWxC).

maybe_transpose_space(observation_space)

This is an extension of the sb3 maybe_transpose_space to support both regular images (HxWxC) and images with an additional dimension (NxHxWxC).

maybe_transpose_obs(observation)

This is an extension of the sb3 maybe_transpose_obs to support both regular images (HxWxC) and images with an additional dimension (NxHxWxC).

Module Contents

is_image_space(observation_space, check_channels=False, normalized_image=False)[source]

This is an extension of the sb3 is_image_space to support both regular images (HxWxC) and images with an additional dimension (NxHxWxC).
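The following is a minimal sketch of the kind of check this function performs, written against plain shapes and dtypes so it runs with numpy alone. The helper name is_image_like is hypothetical, and it assumes channel-last layouts; the real function operates on a gymnasium Box and may detect channel-first images as well.

```python
import numpy as np


def is_image_like(shape, dtype, low, high, check_channels=False, normalized_image=False):
    """Hypothetical re-implementation of the image checks, for illustration only.

    Accepts HxWxC (3D) or NxHxWxC (4D) shapes. Channels are assumed to be last.
    """
    # Regular images are 3D; stacked images carry one extra leading dimension.
    if len(shape) not in (3, 4):
        return False
    # Unnormalized images are expected to be uint8 with bounds [0, 255].
    if not normalized_image:
        if np.dtype(dtype) != np.uint8:
            return False
        if np.any(np.asarray(low) != 0) or np.any(np.asarray(high) != 255):
            return False
    # Optionally require a grayscale, RGB, or RGBA channel count.
    if check_channels and shape[-1] not in (1, 3, 4):
        return False
    return True
```

With this sketch, is_image_like((4, 64, 64, 3), np.uint8, 0, 255) accepts a stack of four RGB frames, which the 3D-only check in sb3 would reject.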

maybe_transpose_space(observation_space)[source]

This is an extension of the sb3 maybe_transpose_space to support both regular images (HxWxC) and images with an additional dimension (NxHxWxC). sb3 calls maybe_transpose_space for the 3D case but not for the 4D case.
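PyTorch convolutions expect channel-first inputs, so the transposed space moves the channel axis ahead of height and width. The sketch below shows only the shape change for both supported cases; transposed_image_shape is a hypothetical helper that assumes a channel-last input and ignores the Box bounds that the real function would also need to transpose.

```python
def transposed_image_shape(shape):
    """Hypothetical helper: move the channel axis of a channel-last image shape.

    HxWxC -> CxHxW and NxHxWxC -> NxCxHxW. Any leading dimension is kept.
    """
    if len(shape) not in (3, 4):
        raise ValueError(f"expected a 3D or 4D image shape, got {shape}")
    *leading, h, w, c = shape
    return (*leading, c, h, w)
```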

maybe_transpose_obs(observation)[source]

This is an extension of the sb3 maybe_transpose_obs to support both regular images (HxWxC) and images with an additional dimension (NxHxWxC). sb3 calls maybe_transpose_obs for the 3D case but not for the 4D case.

Note

In this case, the observation also carries a batch dimension, so a stacked-image observation (NxHxWxC) arrives as a 5D array (BxNxHxWxC).
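Because the channel axis is always last in a channel-last array, one np.moveaxis call handles the 3D, 4D, and batched 5D cases uniformly. This is a sketch under that assumption; transpose_image_obs is a hypothetical name, and it does not attempt to detect arrays that are already channel-first.

```python
import numpy as np


def transpose_image_obs(observation: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: channel-last -> channel-first for 3D, 4D, or 5D arrays.

    (H, W, C)       -> (C, H, W)
    (N, H, W, C)    -> (N, C, H, W)
    (B, N, H, W, C) -> (B, N, C, H, W)   # batched image stack
    """
    if observation.ndim not in (3, 4, 5):
        raise ValueError(f"unsupported observation rank: {observation.ndim}")
    # Moving the last axis to position -3 leaves all leading axes untouched.
    return np.moveaxis(observation, -1, -3)
```

np.moveaxis returns a view, so no pixel data is copied by the transpose itself.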

class MjCambrianCombinedExtractor(observation_space, *, normalized_image, image_extractor, share_image_extractor=False)[source]

Bases: stable_baselines3.common.torch_layers.BaseFeaturesExtractor

Override of the default feature extractor of Stable Baselines 3.
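A combined extractor for a dict observation space typically encodes each image entry with an image extractor, flattens every other entry, and concatenates the results. The sketch below only plans that routing and the resulting output size. plan_combined_extractor is hypothetical, and it assumes that 3D/4D subspaces are images, that each image yields features_dim features, and that share_image_extractor reuses a single extractor instance across all image keys. The real class presumably decides image-ness with is_image_space and normalized_image rather than by rank alone.

```python
import math


def plan_combined_extractor(subspace_shapes, features_dim, share_image_extractor=False):
    """Hypothetical sketch of how a dict observation could be routed.

    Returns the concatenated output size and the number of image-extractor
    instances that would need to be built.
    """
    total_dim = 0
    image_keys = []
    for key, shape in subspace_shapes.items():
        if len(shape) in (3, 4):
            # Assumed image entry: encoded to a fixed-size feature vector.
            image_keys.append(key)
            total_dim += features_dim
        else:
            # Non-image inputs are flattened and concatenated as-is.
            total_dim += math.prod(shape)
    # With sharing, one extractor is reused for every image key.
    if share_image_extractor:
        num_extractors = min(1, len(image_keys))
    else:
        num_extractors = len(image_keys)
    return total_dim, num_extractors
```

For two stacked RGB eyes of shape (1, 32, 32, 3) plus a 3D position vector and features_dim=64, the concatenated output has 64 + 64 + 3 = 131 features in either mode; sharing changes only the number of parameters, not the output size.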

class PermutedFlattenExtractor(observation_space)[source]

Bases: stable_baselines3.common.torch_layers.FlattenExtractor

Feature extractor that flattens the input. Used as a placeholder when feature extraction is not needed.

Parameters:

observation_space – The observation space of the environment
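The class name suggests that the input axes are permuted before flattening. The sketch below assumes that the permutation is the same channel-last to channel-first move used elsewhere in this module, which changes the order of the flattened features but not their count. permuted_flatten is a hypothetical name, and the exact permutation used by the class is an assumption.

```python
import numpy as np


def permuted_flatten(batch: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: channel-last -> channel-first, then flatten per sample.

    batch has shape (B, ..., H, W, C); the output has shape (B, prod(rest)).
    """
    if batch.ndim < 4:
        raise ValueError("expected a batched image of shape (B, ..., H, W, C)")
    # Put channels before height and width, leaving batch/stack axes in place.
    permuted = np.moveaxis(batch, -1, -3)
    # Flatten everything except the batch axis, as a flatten extractor would.
    return permuted.reshape(batch.shape[0], -1)
```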

class MjCambrianImageFeaturesExtractor(observation_space, features_dim, activation)[source]

Bases: stable_baselines3.common.torch_layers.BaseFeaturesExtractor

This is a feature extractor for images. It implements an image queue for temporal features and is intended to be subclassed by other extractors.
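An image queue for temporal features can be kept as a fixed-length buffer that drops the oldest frame when a new one arrives and stacks the buffer into an NxHxWxC array. This is a sketch of that idea only; ImageQueue is a hypothetical class, and the zero-filled initial frames are an assumption about how the queue is padded before it fills up.

```python
from collections import deque

import numpy as np


class ImageQueue:
    """Hypothetical fixed-length frame queue producing an NxHxWxC stack."""

    def __init__(self, num_frames: int, frame_shape: tuple):
        self.frame_shape = tuple(frame_shape)
        # deque(maxlen=...) drops the oldest frame automatically when full.
        self.frames = deque(
            [np.zeros(self.frame_shape, dtype=np.uint8) for _ in range(num_frames)],
            maxlen=num_frames,
        )

    def push(self, frame: np.ndarray) -> np.ndarray:
        if frame.shape != self.frame_shape:
            raise ValueError(f"expected {self.frame_shape}, got {frame.shape}")
        self.frames.append(frame)
        # Oldest frame first, newest frame last along the new leading axis.
        return np.stack(list(self.frames), axis=0)
```

A deque with maxlen gives O(1) appends and evictions, which keeps the per-step cost independent of the queue length apart from the final stack.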

class MjCambrianMLPExtractor(observation_space, features_dim, activation, architecture)[source]

Bases: MjCambrianImageFeaturesExtractor

MLP feature extractor for small images. Essentially NatureCNN but with MLPs.
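For small images, an MLP can replace the convolutional stack by flattening the whole image stack per sample and passing it through dense layers. The numpy forward pass below is a sketch under these assumptions: architecture is a list of hidden-layer sizes, activation is an elementwise callable applied after every layer (as NatureCNN does after its final linear layer), pixels are scaled from uint8 to [0, 1], and the weights are random rather than learned. mlp_features is a hypothetical name.

```python
import numpy as np


def mlp_features(obs, architecture, features_dim, activation=np.tanh, seed=0):
    """Hypothetical numpy sketch of an MLP image encoder (forward pass only).

    obs: (B, N, H, W, C) batched uint8 image stack; architecture: hidden sizes.
    """
    rng = np.random.default_rng(seed)
    # Flatten the whole stack per sample; spatial structure is not exploited.
    x = obs.reshape(obs.shape[0], -1).astype(np.float32) / 255.0
    for out_dim in [*architecture, features_dim]:
        # Random weights stand in for learned parameters in this sketch.
        w = rng.normal(0.0, 0.1, size=(x.shape[1], out_dim)).astype(np.float32)
        x = activation(x @ w)
    return x
```

The first layer's weight count grows with N*H*W*C, which is why this approach suits small images.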

class MjCambrianNatureCNNExtractor(observation_space, features_dim, activation)[source]

Bases: MjCambrianImageFeaturesExtractor

Nature CNN feature extractor for images. This is the default feature extractor for Stable Baselines 3 images. The main differences between this and the original are that this version supports temporal features (i.e. image stacks) and calculates the kernel sizes and strides dynamically. In sb3, the fixed kernel sizes and strides restrict the input image to at least 36x36, which is a bad assumption here.
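The 36x36 limit follows from the unpadded output-size formula (size - kernel) // stride + 1 applied to the three sb3 NatureCNN convolutions (kernel 8 / stride 4, kernel 4 / stride 2, kernel 3 / stride 1). The sketch below checks that arithmetic and contrasts it with one possible dynamic rule. adaptive_out is hypothetical: it shrinks a layer's kernel to the current input size (with stride 1) whenever the input is smaller than the kernel, and it is not necessarily how this class computes its kernels.

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    # Output length of an unpadded convolution along one axis.
    return (size - kernel) // stride + 1


# Kernel sizes and strides of the three sb3 NatureCNN convolutions.
NATURE_LAYERS = [(8, 4), (4, 2), (3, 1)]


def fixed_out(size: int) -> int:
    """Spatial output size with the fixed sb3 kernels (< 1 means invalid)."""
    for kernel, stride in NATURE_LAYERS:
        size = conv_out(size, kernel, stride)
    return size


def adaptive_out(size: int) -> int:
    """Hypothetical rule: shrink a layer's kernel (stride 1) when the input is too small."""
    for kernel, stride in NATURE_LAYERS:
        if size < kernel:
            kernel, stride = size, 1
        size = conv_out(size, kernel, stride)
    return size
```

With the fixed kernels, an 84x84 input (the classic Atari size) yields a 7x7 map, 36x36 yields 1x1, and 35x35 yields 0, so sb3 cannot build the network. Under the hypothetical adaptive rule, every input size of at least 1 produces a map of at least 1x1.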