Source code for cambrian.utils.wrappers

"""Wrappers for the MjCambrianEnv. Used during training."""

from types import NoneType
from typing import Any, Callable, Dict, List, Optional, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.wrappers.numpy_to_torch import numpy_to_torch, torch_to_numpy
from stable_baselines3.common.env_checker import check_env

from cambrian.envs import MjCambrianEnv, MjCambrianEnvConfig
from cambrian.utils import device, is_integer
from cambrian.utils.types import (
    ActionType,
    InfoType,
    ObsType,
    RenderFrame,
    RewardType,
    TerminatedType,
    TruncatedType,
)


class MjCambrianSingleAgentEnvWrapper(gym.Wrapper):
    """Wrapper around the MjCambrianEnv that acts as if there is a single agent.

    Replaces all multi-agent methods so that only a single agent is used (the first
    agent by default).

    Keyword Args:
        agent_name: The name of the agent to use. If not provided, the first agent
            will be used.
        combine_rewards: Whether to sum the rewards of all agents. If False, only the
            selected agent's reward is returned.
        combine_terminated: Whether to terminate when any agent terminates. If False,
            only the selected agent's terminated flag is returned.
        combine_truncated: Whether to truncate when any agent truncates. If False,
            only the selected agent's truncated flag is returned.
    """

    def __init__(
        self,
        env: MjCambrianEnv,
        *,
        agent_name: Optional[str] = None,
        combine_rewards: bool = True,
        combine_terminated: bool = True,
        combine_truncated: bool = True,
    ):
        super().__init__(env)

        self._combine_rewards = combine_rewards
        self._combine_terminated = combine_terminated
        self._combine_truncated = combine_truncated

        agent_name = agent_name or next(iter(env.agents.keys()))
        assert agent_name in env.agents, f"agent {agent_name} not found."
        self._agent = env.agents[agent_name]

        self.action_space = self._agent.action_space
        self.observation_space = self._agent.observation_space

    def reset(self, *args, **kwargs) -> Tuple[ObsType, InfoType]:
        obs, info = self.env.reset(*args, **kwargs)

        return obs[self._agent.name], info[self._agent.name]

    def step(
        self, action: ActionType
    ) -> Tuple[ObsType, RewardType, TerminatedType, TruncatedType, InfoType]:
        action = {self._agent.name: action}
        obs, reward, terminated, truncated, info = self.env.step(action)

        obs = obs[self._agent.name]
        info = info[self._agent.name]

        if self._combine_rewards:
            reward = sum(reward.values())
        else:
            reward = reward[self._agent.name]

        if self._combine_terminated:
            terminated = any(terminated.values())
        else:
            terminated = terminated[self._agent.name]

        if self._combine_truncated:
            truncated = any(truncated.values())
        else:
            truncated = truncated[self._agent.name]

        return obs, reward, terminated, truncated, info
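

# Illustrative usage sketch (not part of the original module): roll out a single
# wrapped agent with random actions. ``env`` is assumed to be a fully constructed
# ``MjCambrianEnv``; the helper name ``_example_single_agent_rollout`` is hypothetical.
def _example_single_agent_rollout(env: MjCambrianEnv, num_steps: int = 10) -> None:
    wrapped = MjCambrianSingleAgentEnvWrapper(env, combine_rewards=True)
    obs, info = wrapped.reset()
    for _ in range(num_steps):
        # The wrapper exposes the selected agent's action/observation spaces directly
        action = wrapped.action_space.sample()
        obs, reward, terminated, truncated, info = wrapped.step(action)
        if terminated or truncated:
            obs, info = wrapped.reset()
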

class MjCambrianPettingZooEnvWrapper(gym.Wrapper):
    """Wrapper around the MjCambrianEnv that acts as if there is a single agent when,
    in actuality, there are multiple agents.

    SB3 doesn't support Dict action spaces, so this wrapper will flatten the actions
    into a single space. The observation can be a dict; however, nested dicts are not
    allowed.
    """

    def __init__(self, env: MjCambrianEnv):
        super().__init__(env)

        self.env: MjCambrianEnv

    def reset(self, *args, **kwargs) -> Tuple[ObsType, InfoType]:
        obs, info = self.env.reset(*args, **kwargs)

        # Flatten the observations
        flattened_obs: Dict[str, Any] = {}
        for agent_name, agent_obs in obs.items():
            if isinstance(agent_obs, dict):
                for key, value in agent_obs.items():
                    flattened_obs[f"{agent_name}_{key}"] = value
            else:
                flattened_obs[agent_name] = agent_obs

        return flattened_obs, info

    def step(
        self, action: ActionType
    ) -> Tuple[ObsType, RewardType, TerminatedType, TruncatedType, InfoType]:
        # Convert the action back to a dict
        action = action.reshape(-1, len(self.env.agents))
        action = {
            agent_name: action[:, i]
            for i, agent_name in enumerate(self.env.agents.keys())
            if self.env.agents[agent_name].config.trainable
        }

        obs, reward, terminated, truncated, info = self.env.step(action)

        # Accumulate the rewards, terminated, and truncated
        reward = sum(reward.values())
        terminated = any(terminated.values())
        truncated = any(truncated.values())

        # Flatten the observations
        flattened_obs: Dict[str, Any] = {}
        for agent_name, agent_obs in obs.items():
            if isinstance(agent_obs, dict):
                for key, value in agent_obs.items():
                    flattened_obs[f"{agent_name}_{key}"] = value
            else:
                flattened_obs[agent_name] = agent_obs

        return flattened_obs, reward, terminated, truncated, info

    @property
    def observation_space(self) -> gym.spaces.Dict:
        """SB3 doesn't support nested Dict observation spaces, so we'll flatten it.
        If each agent has a Dict observation space, we'll flatten it into a single
        observation where the key in the dict is the agent name and the original
        space name."""
        observation_space: Dict[str, gym.Space] = {}
        for agent in self.env.agents.values():
            agent_observation_space = agent.observation_space
            if isinstance(agent_observation_space, gym.spaces.Dict):
                for key, value in agent_observation_space.spaces.items():
                    observation_space[f"{agent.name}_{key}"] = value
            else:
                observation_space[agent.name] = agent_observation_space

        return gym.spaces.Dict(observation_space)

    @property
    def action_space(self) -> gym.spaces.Box:
        """The only continuous gym.Space that SB3 supports for the action space is a
        Box. We can assume each agent's action space is a Box, so we'll flatten each
        action space into one Box for the environment.

        Assumptions:
        - All agents have the same number of actions
        - All actions have the same shape
        - All actions are continuous
        - All actions are normalized between -1 and 1
        """

        # Get the first agent's action space
        first_agent_name = next(iter(self.env.agents.keys()))
        first_agent_action_space = self.env.agents[first_agent_name].action_space

        # Check if the action space is continuous
        assert isinstance(first_agent_action_space, gym.spaces.Box), (
            "SB3 only supports continuous action spaces for the environment. "
            f"agent {first_agent_name} has a {type(first_agent_action_space)}"
            " action space."
        )

        # Get the shape of the action space
        shape = first_agent_action_space.shape
        low = first_agent_action_space.low
        high = first_agent_action_space.high

        # Check if all agents have the same number of actions
        for agent_name, agent_action_space in self.env.action_spaces.items():
            assert shape == agent_action_space.shape, (
                "All agents must have the same number of actions. "
                f"agent {first_agent_name} has {shape} actions, but {agent_name} "
                f"has {agent_action_space.shape} actions."
            )

            # Check if the action space is continuous
            assert isinstance(agent_action_space, gym.spaces.Box), (
                "SB3 only supports continuous action spaces for the environment. "
                f"agent {agent_name} has a "
                f"{type(agent_action_space)} action space."
            )

            assert all(low == agent_action_space.low), (
                "All actions must have the same low value. "
                f"agent {first_agent_name} has a low value of {low}, "
                f"but {agent_name} has a low value of {agent_action_space.low}."
            )
            assert all(high == agent_action_space.high), (
                "All actions must have the same high value. "
                f"agent {first_agent_name} has a high value of {high}, "
                f"but {agent_name} has a high value of {agent_action_space.high}."
            )

        low = np.tile(low, len(self.env.agents))
        high = np.tile(high, len(self.env.agents))
        shape = (shape[0] * len(self.env.agents),)
        return gym.spaces.Box(
            low=low, high=high, shape=shape, dtype=first_agent_action_space.dtype
        )
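

# Illustrative usage sketch (not part of the original module): the flattened
# observation keys are "<agent name>_<obs key>" and a single flat Box action covers
# all agents. ``env`` is assumed to be a fully constructed multi-agent
# ``MjCambrianEnv``; the helper name ``_example_pettingzoo_rollout`` is hypothetical.
def _example_pettingzoo_rollout(env: MjCambrianEnv, num_steps: int = 10) -> None:
    wrapped = MjCambrianPettingZooEnvWrapper(env)
    obs, info = wrapped.reset()
    for _ in range(num_steps):
        # One flat action for all agents; ``step`` splits it back out per agent
        action = wrapped.action_space.sample()
        obs, reward, terminated, truncated, info = wrapped.step(action)
        if terminated or truncated:
            obs, info = wrapped.reset()
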

class MjCambrianConstantActionWrapper(gym.Wrapper):
    """This wrapper will apply a constant action at specific indices of the action
    space.

    Args:
        constant_actions: A dictionary where the keys are the indices of the action
            space and the values are the constant actions to apply.
    """

    def __init__(self, env: MjCambrianEnv, constant_actions: Dict[Any, Any]):
        super().__init__(env)

        self._constant_action_indices = [
            int(k) if is_integer(k) else k for k in constant_actions.keys()
        ]
        self._constant_action_values = list(constant_actions.values())

    def step(
        self, action: ActionType
    ) -> Tuple[ObsType, RewardType, TerminatedType, TruncatedType, InfoType]:
        if isinstance(action, dict):
            assert all(idx in action for idx in self._constant_action_indices), (
                "The constant action indices must be in the action space. "
                f"Indices: {self._constant_action_indices}, Action space: {action}"
            )
        action[self._constant_action_indices] = self._constant_action_values

        return self.env.step(action)
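

# Illustrative usage sketch (not part of the original module): pin the first action
# entry to a constant while sampling the rest. Assumes a flat, array-like action
# space, e.g. an env already wrapped with ``MjCambrianSingleAgentEnvWrapper``; the
# helper name ``_example_constant_action`` and the constant value are hypothetical.
def _example_constant_action(env: MjCambrianEnv) -> None:
    wrapped = MjCambrianConstantActionWrapper(
        MjCambrianSingleAgentEnvWrapper(env), constant_actions={0: 0.0}
    )
    obs, info = wrapped.reset()
    action = wrapped.action_space.sample()
    # Index 0 of the sampled action is overwritten with 0.0 inside ``step``
    obs, reward, terminated, truncated, info = wrapped.step(action)
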

# Register no-op conversions so torch_to_numpy passes through values that are
# already numpy arrays or None.
@torch_to_numpy.register(np.ndarray)
def _(value: np.ndarray) -> np.ndarray:
    return value


@torch_to_numpy.register(NoneType)
def _(value: NoneType) -> NoneType:
    return value

class MjCambrianTorchToNumpyWrapper(gym.Wrapper):
    """Wraps a torch-based environment to convert inputs and outputs to NumPy
    arrays."""

    def __init__(self, env: gym.Env, *, convert_action: bool = False):
        """Wrapper class to change inputs and outputs of environment to numpy arrays.

        Args:
            env: The torch-based environment

        Keyword Args:
            convert_action: Whether to convert the action to a numpy array
        """
        super().__init__(env)

        self._convert_action = convert_action

    def step(
        self, actions: ActionType
    ) -> Tuple[ObsType, RewardType, TerminatedType, TruncatedType, InfoType]:
        """Steps the environment with a numpy-based action, converting it to torch
        for the underlying environment and converting the outputs back to numpy.

        Args:
            actions: A numpy-based action

        Returns:
            The numpy-based observation, reward, termination, truncation, and extra
            info
        """
        actions = (
            numpy_to_torch(actions, device=device) if self._convert_action else actions
        )
        obs, reward, terminated, truncated, info = self.env.step(actions)

        return (
            torch_to_numpy(obs),
            reward,
            terminated,
            truncated,
            torch_to_numpy(info),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> Tuple[ObsType, InfoType]:
        """Resets the environment, returning numpy-based observations and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment

        Returns:
            The numpy-based observation and extra info
        """
        if options:
            options = numpy_to_torch(options, device=device)

        obs, info = self.env.reset(seed=seed, options=options)

        return torch_to_numpy(obs), torch_to_numpy(info)

    def render(self) -> RenderFrame | List[RenderFrame] | None:
        """Renders the environment, returning a numpy-based image.

        Returns:
            The numpy-based image
        """
        return torch_to_numpy(self.env.render())
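

# Illustrative usage sketch (not part of the original module): expose a torch-based
# MjCambrianEnv to numpy-only consumers such as SB3. Assumes the underlying env
# returns torch tensors and that actions should also be converted to torch; the
# helper name ``_example_torch_to_numpy`` is hypothetical.
def _example_torch_to_numpy(env: MjCambrianEnv) -> None:
    wrapped = MjCambrianTorchToNumpyWrapper(env, convert_action=True)
    obs, info = wrapped.reset(seed=0)
    action = wrapped.action_space.sample()  # numpy action, converted to torch in step
    obs, reward, terminated, truncated, info = wrapped.step(action)
    frame = wrapped.render()  # numpy-based image (or None)
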

def make_wrapped_env(
    config: MjCambrianEnvConfig,
    wrappers: List[Callable[[gym.Env], gym.Env]],
    seed: Optional[int] = None,
    **kwargs,
) -> Callable[[], gym.Env]:
    """Utility function for creating a wrapped MjCambrianEnv. Returns a callable
    that constructs the environment when invoked."""

    def _init():
        env = config.instance(config, **kwargs)
        for wrapper in wrappers:
            env = wrapper(env)
        # check_env will call reset and set the seed to 0; call set_random_seed after
        check_env(env, warn=False)
        env.unwrapped.set_random_seed(seed)
        return env

    return _init
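

# Illustrative usage sketch (not part of the original module): build a seeded,
# single-agent training env from a config. ``config`` is assumed to be a valid
# ``MjCambrianEnvConfig``; since ``make_wrapped_env`` returns a thunk, the thunk can
# also be passed directly to SB3 vectorized env constructors such as DummyVecEnv.
# The helper name ``_example_make_wrapped_env`` is hypothetical.
def _example_make_wrapped_env(config: MjCambrianEnvConfig) -> gym.Env:
    env_fn = make_wrapped_env(
        config, wrappers=[MjCambrianSingleAgentEnvWrapper], seed=0
    )
    return env_fn()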