cambrian.ml.features_extractors
This module contains custom feature extractors for use in the models.
Classes
- MjCambrianCombinedExtractor: Overrides the default feature extractor of Stable Baselines 3.
- PermutedFlattenExtractor: Feature extractor that flattens the input.
- MjCambrianImageFeaturesExtractor: Base feature extractor for images, intended to maintain an image queue for temporal features.
- MjCambrianMLPExtractor: MLP feature extractor for small images.
- MjCambrianNatureCNNExtractor: Nature CNN feature extractor for images.
Functions
- is_image_space: Extension of the SB3 is_image_space that supports both regular images (HxWxC) and image stacks (NxHxWxC).
- maybe_transpose_space: Extension of the SB3 maybe_transpose_space that supports both regular images (HxWxC) and image stacks (NxHxWxC).
- maybe_transpose_obs: Extension of the SB3 maybe_transpose_obs that supports both regular images (HxWxC) and image stacks (NxHxWxC).
Module Contents
- is_image_space(observation_space, check_channels=False, normalized_image=False)
Extension of the SB3 is_image_space that accepts both regular images (HxWxC) and images with an additional leading dimension (NxHxWxC), such as a stack of N frames.
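A minimal shape-level sketch of what this extension implies, assuming it keeps SB3's uint8 and [0, 255] checks and its 1/3/4-channel rule, and only relaxes the rank requirement from exactly 3 to 3 or 4. The function `looks_like_image` and its arguments (a shape tuple, a dtype name, and scalar bounds standing in for a `gymnasium.spaces.Box`) are hypothetical; this is not the module's actual code.

```python
from typing import Tuple


def looks_like_image(
    shape: Tuple[int, ...],
    dtype: str,
    low: float,
    high: float,
    check_channels: bool = False,
    normalized_image: bool = False,
) -> bool:
    """Hypothetical shape-level version of an is_image_space check."""
    # Accept HxWxC (3D) like SB3 does, plus NxHxWxC (4D) image stacks.
    if len(shape) not in (3, 4):
        return False
    # Unless images are already normalized, require raw uint8 pixels in [0, 255].
    if not normalized_image and (dtype != "uint8" or low != 0 or high != 255):
        return False
    if not check_channels:
        return True
    # Assumption: mirror SB3's heuristic on the last three axes. If the smallest
    # axis comes first the image is channels-first, otherwise channels-last.
    image_shape = shape[-3:]
    channels_first = image_shape.index(min(image_shape)) == 0
    n_channels = image_shape[0] if channels_first else image_shape[-1]
    return n_channels in (1, 3, 4)
```

For example, `looks_like_image((4, 64, 64, 3), "uint8", 0, 255, check_channels=True)` accepts a stack of four RGB frames, which SB3's 3D-only check would reject.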
- maybe_transpose_space(observation_space)
Extension of the SB3 maybe_transpose_space that supports both regular images (HxWxC) and images with an additional dimension (NxHxWxC). SB3 calls maybe_transpose_space for the 3D case, but not for the 4D case.
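SB3 reorders image spaces into PyTorch's channels-first layout. The sketch below shows the target shapes for both cases this function handles, assuming the same "smallest axis is the channel axis" heuristic SB3 uses. The helper `channels_first_shape` is hypothetical and works on shape tuples rather than real `Box` spaces.

```python
from typing import Tuple


def channels_first_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical: HxWxC -> CxHxW and NxHxWxC -> NxCxHxW.

    Shapes whose last three axes already look channels-first are
    returned unchanged.
    """
    if len(shape) not in (3, 4):
        raise ValueError(f"expected a 3D or 4D image shape, got {shape}")
    # Leading stack axis (if any) stays in place; only H, W, C are reordered.
    lead, (a, b, c) = shape[:-3], shape[-3:]
    # Assumption: if the first of the last three axes is the smallest,
    # the image is already channels-first.
    if min(a, b, c) == a:
        return shape
    return (*lead, c, a, b)
```

The stack axis N stays in front, so a 4D stack becomes NxCxHxW rather than being treated as a single image.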
- maybe_transpose_obs(observation)
Extension of the SB3 maybe_transpose_obs that supports both regular images (HxWxC) and images with an additional dimension (NxHxWxC). SB3 calls maybe_transpose_obs for the 3D case, but not for the 4D case.
Note
In this case the observation also carries a batch dimension, so an image-stack observation is 5D (BxNxHxWxC).
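Because of the batch dimension, the transpose has to leave every leading axis alone and reorder only the trailing H, W, C axes. A minimal sketch of that axis bookkeeping, assuming channels-last input; `transpose_axes` and `permute_shape` are hypothetical helpers, not part of this module.

```python
from typing import Tuple


def transpose_axes(ndim: int) -> Tuple[int, ...]:
    """Hypothetical: axes that turn (..., H, W, C) into (..., C, H, W).

    Leading axes (batch, stack) keep their order; only the last three move.
    """
    if ndim < 3:
        raise ValueError("an image observation needs at least 3 dimensions")
    lead = tuple(range(ndim - 3))
    return lead + (ndim - 1, ndim - 3, ndim - 2)


def permute_shape(shape: Tuple[int, ...], axes: Tuple[int, ...]) -> Tuple[int, ...]:
    # Apply an axis permutation to a shape, as a transpose would.
    return tuple(shape[i] for i in axes)
```

A batched stack of shape (8, 4, 64, 64, 3) maps to (8, 4, 3, 64, 64). With a NumPy array, the same tuple can be passed to `np.transpose(obs, transpose_axes(obs.ndim))`.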
- class MjCambrianCombinedExtractor(observation_space, *, normalized_image, image_extractor, share_image_extractor=False)
Bases: stable_baselines3.common.torch_layers.BaseFeaturesExtractor
Overrides the default feature extractor of Stable Baselines 3.
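SB3's own `CombinedExtractor` walks a Dict observation space, sends image entries through a CNN, flattens the rest, and concatenates the results. Assuming this override follows the same pattern with `image_extractor` in place of the CNN, the sketch below shows how `share_image_extractor` would affect the number of image-extractor instances but not the output width. The helper `plan_combined_extractor`, the example keys, and the rule that each image extractor emits `features_dim` values are all assumptions.

```python
from math import prod
from typing import Dict, Set, Tuple


def plan_combined_extractor(
    subspace_shapes: Dict[str, Tuple[int, ...]],
    image_keys: Set[str],
    features_dim: int,
    share_image_extractor: bool,
) -> Tuple[int, int]:
    """Hypothetical: return (number of image extractors, total output width)."""
    total_dim = 0
    n_image_keys = 0
    for key, shape in subspace_shapes.items():
        if key in image_keys:
            # Assumption: every image extractor outputs features_dim values.
            n_image_keys += 1
            total_dim += features_dim
        else:
            # Non-image entries are flattened and concatenated as-is.
            total_dim += prod(shape)
    # A shared extractor is built once and reused for every image key.
    n_extractors = min(n_image_keys, 1) if share_image_extractor else n_image_keys
    return n_extractors, total_dim
```

With two 4x32x32x3 image stacks, a 2-element vector, and `features_dim=64`, both settings give a 130-wide output. Sharing builds one image extractor instead of two, which saves parameters at the cost of treating every image input identically.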
- class PermutedFlattenExtractor(observation_space)
Bases: stable_baselines3.common.torch_layers.FlattenExtractor
Feature extractor that flattens the input. Used as a placeholder when feature extraction is not needed.
- Parameters:
observation_space – The observation space of the environment
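Like SB3's `FlattenExtractor`, a flatten-only extractor outputs one value per observation element, so its feature width is the product of the observation shape. One plausible reading of "permuted" (an assumption, not stated in the source) is that axes are reordered, for example to channels-first, before flattening. The sketch below shows that this changes the element order but not the count. `flat_features_dim` and `permuted_flatten` are hypothetical.

```python
from math import prod
from typing import List, Sequence, Tuple


def flat_features_dim(obs_shape: Tuple[int, ...]) -> int:
    """Hypothetical: output width of a flatten-only extractor."""
    return prod(obs_shape)


def permuted_flatten(image: Sequence[Sequence[Sequence[float]]]) -> List[float]:
    """Hypothetical: flatten an HxWxC nested list in channels-first (C, H, W) order."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        image[h][w][c]
        for c in range(channels)
        for h in range(height)
        for w in range(width)
    ]
```

For a 2x2x2 image, all values of channel 0 come first, then all values of channel 1. The list still has `flat_features_dim((2, 2, 2)) == 8` entries.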
- class MjCambrianImageFeaturesExtractor(observation_space, features_dim, activation)
Bases: stable_baselines3.common.torch_layers.BaseFeaturesExtractor
Base feature extractor for images. It is intended to maintain an image queue for temporal features, and it should be subclassed rather than used directly.
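An image queue for temporal features is typically a fixed-length buffer of the N most recent frames, which lines up with the NxHxWxC layout used elsewhere in this module. A minimal sketch using `collections.deque` with `maxlen`. The class `FrameQueue` and the choice to prefill with blank frames are assumptions, not this class's API.

```python
from collections import deque
from typing import Any, Deque, List


class FrameQueue:
    """Hypothetical fixed-length queue of the most recent image frames."""

    def __init__(self, num_frames: int, blank_frame: Any):
        # Prefill so the stack always has num_frames entries; maxlen makes
        # the deque drop the oldest frame automatically on append.
        self._frames: Deque[Any] = deque([blank_frame] * num_frames, maxlen=num_frames)

    def push(self, frame: Any) -> List[Any]:
        """Add the newest frame and return the stack, oldest first."""
        self._frames.append(frame)
        return list(self._frames)
```

With `num_frames=3`, the first push returns two blank frames followed by the new one. After four pushes, only the last three frames remain.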
- class MjCambrianMLPExtractor(observation_space, features_dim, activation, architecture)
Bases: MjCambrianImageFeaturesExtractor
MLP feature extractor for small images. Essentially NatureCNN, but built from fully connected (MLP) layers instead of convolutions.
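One way to picture "NatureCNN but with MLPs" is to flatten the image (or image stack) and pass it through fully connected layers that end in `features_dim` outputs, just as NatureCNN ends in a linear layer of that width. The source does not define `architecture`. The sketch assumes it is a list of hidden-layer widths, and the helper `layer_shapes` is hypothetical.

```python
from math import prod
from typing import List, Sequence, Tuple


def layer_shapes(
    obs_shape: Sequence[int], architecture: Sequence[int], features_dim: int
) -> List[Tuple[int, int]]:
    """Hypothetical: (in, out) sizes of each linear layer in an MLP extractor."""
    # The flattened observation feeds the first hidden layer, and the last
    # layer projects to features_dim.
    sizes = [prod(obs_shape), *architecture, features_dim]
    return list(zip(sizes[:-1], sizes[1:]))
```

For a 4x8x8x3 stack with `architecture=[128, 64]` and `features_dim=32`, this gives (768, 128), (128, 64), (64, 32). Each pair would map to a `torch.nn.Linear(in, out)` followed by the `activation`. Because the first layer grows with H x W, an MLP stays cheap only for small images.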
- class MjCambrianNatureCNNExtractor(observation_space, features_dim, activation)
Bases: MjCambrianImageFeaturesExtractor
Nature CNN feature extractor for images. NatureCNN is the default feature extractor for image observations in Stable Baselines 3. This version differs from the original in two ways: it supports temporal features (i.e. image stacks), and it calculates the kernel sizes and strides dynamically. In SB3, the fixed kernel sizes and strides require images of at least 36x36, which is a bad assumption here.
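SB3's NatureCNN stacks three unpadded convolutions with (kernel, stride) of (8, 4), (4, 2), and (3, 1). The arithmetic below shows where the size limit comes from: 84 pixels shrink to the familiar 7x7 map, 36 is the smallest square input that still leaves a 1x1 map, and 35 collapses to nothing. The `shrink_layers` function illustrates one possible dynamic scheme that clamps each kernel and stride to the size it receives. It is an assumption for illustration, not this class's actual rule.

```python
from typing import List, Sequence, Tuple

# SB3's NatureCNN conv stack as (kernel_size, stride), all without padding.
NATURE_CNN_LAYERS: Tuple[Tuple[int, int], ...] = ((8, 4), (4, 2), (3, 1))


def conv_out(size: int, kernel: int, stride: int) -> int:
    """Spatial size after an unpadded convolution (0 if the input is too small)."""
    return max((size - kernel) // stride + 1, 0)


def final_size(size: int, layers: Sequence[Tuple[int, int]] = NATURE_CNN_LAYERS) -> int:
    """Spatial size after running every layer in order (0 means the stack fails)."""
    for kernel, stride in layers:
        size = conv_out(size, kernel, stride)
        if size == 0:
            return 0
    return size


def shrink_layers(
    size: int, layers: Sequence[Tuple[int, int]] = NATURE_CNN_LAYERS
) -> List[Tuple[int, int]]:
    """Hypothetical dynamic scheme: clamp each kernel and stride to the input it sees."""
    adapted = []
    for kernel, stride in layers:
        kernel = min(kernel, size)
        stride = min(stride, kernel)
        adapted.append((kernel, stride))
        size = conv_out(size, kernel, stride)
    return adapted
```

For a 16x16 input, the fixed stack fails (`final_size(16) == 0`), while `shrink_layers(16)` yields (8, 4), (3, 2), (1, 1) and ends in a 1x1 map. Because every kernel is clamped to at most the incoming size, each layer always outputs at least one pixel.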