"""Defines the MjCambrianEnv class."""
import pickle
import time
from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Self, Tuple
import mujoco as mj
import numpy as np
from gymnasium import Env, spaces
from hydra_config import HydraContainerConfig, config_wrapper
from pettingzoo import ParallelEnv
from cambrian.agents.agent import MjCambrianAgent, MjCambrianAgentConfig
from cambrian.renderer import (
MjCambrianRenderer,
MjCambrianRendererConfig,
MjCambrianRendererSaveMode,
)
from cambrian.renderer.overlays import MjCambrianCursor, MjCambrianViewerOverlay
from cambrian.utils.cambrian_xml import MjCambrianXML, MjCambrianXMLConfig
from cambrian.utils.logger import get_logger
from cambrian.utils.spec import MjCambrianSpec, spec_from_xml
from cambrian.utils.types import (
ActionType,
InfoType,
MjCambrianRewardFn,
MjCambrianStepFn,
MjCambrianTerminationFn,
MjCambrianTruncationFn,
ObsType,
RenderFrame,
RewardType,
TerminatedType,
TruncatedType,
)
# ======================
@config_wrapper
class MjCambrianEnvConfig(HydraContainerConfig):
"""Defines a config for the cambrian environment.
Attributes:
instance (Callable[[Self], "MjCambrianEnv"]): The class method to use to
instantiate the environment.
xml (MjCambrianXMLConfig): The xml for the scene. This is the xml that will be
used to create the environment. See `MjCambrianXML` for more info.
step_fn (MjCambrianStepFn): The step function to use. See the `MjCambrianStepFn`
for more info. The step fn is called before the termination, truncation, and
reward fns, and after the action has been applied to the agents. It takes
the environment, the observations, the info dict, and any additional kwargs.
Returns the updated observations and info dict.
termination_fn (MjCambrianTerminationFn): The termination function to use. See
the :class:`MjCambrianTerminationFn` for more info.
truncation_fn (MjCambrianTruncationFn): The truncation function to use. See the
:class:`MjCambrianTruncationFn` for more info.
reward_fn (MjCambrianRewardFn): The reward function type to use. See the
:class:`MjCambrianRewardFn` for more info.
frame_skip (int): The number of mujoco simulation steps per `gym.step()` call.
max_episode_steps (int): The maximum number of steps per episode.
n_eval_episodes (int): The number of episodes to evaluate for.
add_overlays (bool): Whether to add overlays or not.
clear_overlays_on_reset (bool): Whether to clear the overlays on reset or not.
If set to False, then when drawing position overlays and when mazes change
between evaluations, the sites will be drawn on top of each other, which may
not be desired. When record is False, the overlays are always cleared.
debug_overlays_size (float): The size of the debug overlays. This is a
percentage of the total renderer size. If 0, debug overlays are disabled.
renderer (Optional[MjCambrianRendererConfig]): The default renderer config to
use for the mujoco viewer. If unset, no renderer will be used. Should be
set to None if `render` will never be called. This may be useful to
reduce the amount of VRAM consumed by non-rendering environments.
save_filename (Optional[str]): The filename to save recordings to. This is more
of a placeholder for external scripts to use, if desired.
agents (Dict[str, MjCambrianAgentConfig]): The configs for the agents, keyed
by name. The key will be used as the default name for the agent, unless
explicitly set in the agent config.
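Example: a minimal sketch of a custom reward function matching the signature
this environment calls it with (env, agent, terminated, truncated, info). The
per-agent info dict is stateful for the episode, so values such as "prev_pos"
or "has_contacts" set during the step are available here. The body below is
illustrative only::

    def my_reward_fn(env, agent, terminated, truncated, info) -> float:
        # Hypothetical shaping: reward forward progress, penalize contacts
        prev_pos = info.get("prev_pos", agent.pos)
        progress = float(np.linalg.norm(agent.pos - prev_pos))
        penalty = -1.0 if info.get("has_contacts", False) else 0.0
        return progress + penalty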
"""
instance: Callable[[Self], "MjCambrianEnv"]
xml: MjCambrianXMLConfig
step_fn: MjCambrianStepFn
termination_fn: MjCambrianTerminationFn
truncation_fn: MjCambrianTruncationFn
reward_fn: MjCambrianRewardFn
frame_skip: int
max_episode_steps: int
n_eval_episodes: int
add_overlays: bool
clear_overlays_on_reset: bool
debug_overlays_size: float
renderer: Optional[MjCambrianRendererConfig] = None
save_filename: Optional[str] = None
agents: Dict[str, MjCambrianAgentConfig | Any]
class MjCambrianEnv(ParallelEnv, Env):
"""A MjCambrianEnv defines a gymnasium environment that's based off mujoco.
NOTES:
- This is an overridden version of the MujocoEnv class. The two main differences
are that we allow for (and can reset) multiple agents and that we use our own
custom renderer. It also reduces the need to create temporary xml files, which
MujocoEnv had to load. It's essentially a copy of MujocoEnv with the two
aforementioned major changes.
Args:
config (MjCambrianEnvConfig): The config object.
name (Optional[str]): The name of the environment. This is added as an overlay
to the renderer.
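Example: a minimal rollout sketch (assumes `config` is an already-composed
MjCambrianEnvConfig)::

    env = config.instance(config)
    obs, info = env.reset(seed=0)
    for _ in range(env.max_episode_steps):
        action = {
            name: agent.action_space.sample()
            for name, agent in env.agents.items()
            if agent.trainable
        }
        obs, reward, terminated, truncated, info = env.step(action)
        if any(terminated.values()) or any(truncated.values()):
            break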
"""
metadata = {"render_modes": ["human", "rgb_array"]}
def __init__(self, config: MjCambrianEnvConfig, name: Optional[str] = None):
self._config = config
self._name = name or self.__class__.__name__
self._agents: Dict[str, MjCambrianAgent] = {}
self._create_agents()
self._xml = self.generate_xml()
try:
self._spec = spec_from_xml(self._xml)
self._spec.compile()
except Exception:
get_logger().error(f"Error creating model\n{self._xml.to_string()}")
raise
self._spec.env = self
self.render_mode = "rgb_array"
self._renderer: Optional[MjCambrianRenderer] = None
if renderer_config := self._config.renderer:
if "human" in self._config.renderer.render_modes:
self.render_mode = "human"
self._renderer = MjCambrianRenderer(renderer_config)
self._episode_step = 0
self._max_episode_steps = self._config.max_episode_steps
self._num_resets = 0
self._num_timesteps = 0
self._stashed_cumulative_reward = 0
self._cumulative_reward = 0
self._timings = deque(maxlen=25)
self._record: bool = False
self._rollout: Dict[str, Any] = {}
self._overlays: Dict[str, Any] = {}
# We'll store the info dict as state within this class so that the truncation,
# termination, and reward functions can use it to keep state. Passing the info
# dict to these functions allows them to edit it and retain information that is
# helpful for subsequent calls. It is always reset during the reset method and
# is only maintained for the duration of an episode. Because the info dict is
# treated as stateful, take care not to add new keys on each step, as this will
# cause the info dict to grow until the end of the episode.
self._info: Dict[str, Dict[str, Any]]
def _create_agents(self):
"""Helper method to create the agents."""
for name, agent_config in self._config.agents.items():
assert name not in self._agents, f"Agent {name} already exists."
self._agents[name] = agent_config.instance(agent_config, name)
def generate_xml(self) -> MjCambrianXML:
"""Generates the xml for the environment.
.. todo::
Can we update to use MjSpec?
"""
xml = MjCambrianXML.from_string(self._config.xml)
# Add the agents to the xml
for agent in self._agents.values():
xml += agent.generate_xml()
return xml
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[Any, Any]] = None
) -> Tuple[ObsType, InfoType]:
"""Reset the environment.
Will reset all underlying components (the maze, the agents, etc.). The
simulation will then be stepped once to ensure that the observations are
up-to-date.
Returns:
Tuple[ObsType, InfoType]: The observations for each
agent and the info dict for each agent.
"""
if seed is not None and self._num_resets == 0:
self.set_random_seed(seed)
# First, reset the mujoco simulation
mj.mj_resetData(self._spec.model, self._spec.data)
# Reset the info dict. We'll update the stateful info dict here, as well.
info: Dict[str, Dict[str, Any]] = {a: {} for a in self._agents}
self._info = info
# Then, reset the agents
obs: Dict[str, Dict[str, Any]] = {}
for name, agent in self._agents.items():
obs[name] = agent.reset(self._spec)
# Recompile the model/data
self._spec.recompile()
# We'll step the simulation once to allow for states to propagate
self._step_mujoco_simulation(1, info)
# Reset the renderer, if one is being used
if self._renderer is not None:
self._renderer.reset(self._spec)
# Update metadata variables
self._episode_step = 0
self._stashed_cumulative_reward = self._cumulative_reward
self._cumulative_reward = 0
self._num_resets += 1
self._timings.clear()
self._timings.append(time.time())
# Reset the caches
if not self._record:
self._rollout.clear()
self._overlays.clear()
elif self._config.clear_overlays_on_reset:
self._overlays.clear()
if self._record:
# TODO make this cleaner
self._rollout.setdefault("actions", [])
self._rollout["actions"].append(
[np.zeros_like(a.action_space.sample()) for a in self._agents.values()]
)
self._rollout.setdefault("positions", [])
self._rollout["positions"].append([a.qpos for a in self._agents.values()])
return self._config.step_fn(self, obs, info)
def step(
self, action: ActionType
) -> Tuple[ObsType, RewardType, TerminatedType, TruncatedType, InfoType]:
"""Step the environment.
The dynamics are updated through the `_step_mujoco_simulation` method.
Args:
action (Dict[str, Any]): The action to take for each agent.
The keys define the agent name, and the values define the action for
that agent.
Returns:
Dict[str, Any]: The observations for each agent.
Dict[str, float]: The reward for each agent.
Dict[str, bool]: Whether each agent has terminated.
Dict[str, bool]: Whether each agent has truncated.
Dict[str, Dict[str, Any]]: The info dict for each agent.
"""
info = self._info
# First, apply the actions to the agents and step the simulation
for name, agent in self._agents.items():
if not agent.trainable or agent.config.use_privileged_action:
if not agent.trainable and name in action:
get_logger().warning(
f"Action for {name} found in action dict. "
"This will be overridden by the agent.",
extra={"once": True},
)
action[name] = agent.get_action_privileged(self)
assert name in action, f"Action for {name} not found in action dict."
agent.apply_action(action[name])
info[name]["prev_pos"] = agent.pos.copy()
info[name]["action"] = action[name]
# Then, step the mujoco simulation
self._step_mujoco_simulation(self._config.frame_skip, info)
# We'll then step each agent to render its current state and get the obs
obs: Dict[str, Any] = {}
for name, agent in self._agents.items():
obs[name] = agent.step()
# Call helper methods to update the observations, rewards, terminated, and info
obs, info = self._config.step_fn(self, obs, info)
terminated = self._compute_terminated(info)
truncated = self._compute_truncated(info)
reward = self._compute_reward(terminated, truncated, info)
self._episode_step += 1
self._num_timesteps += 1
self._cumulative_reward += sum(reward.values())
if self._record:
self._rollout["actions"].append(list(action.values()))
self._rollout["positions"].append([a.pos for a in self._agents.values()])
if (
self._config.debug_overlays_size > 0
and self._record
or "human" in self._config.renderer.render_modes
):
self._overlays["Name"] = self._name
self._overlays["Total Timesteps"] = self.num_timesteps
self._overlays["Step"] = self._episode_step
self._overlays["Cumulative Reward"] = round(self._cumulative_reward, 2)
self._timings.append(time.time())
fps = (len(self._timings) - 1) / (self._timings[-1] - self._timings[0])
self._overlays["FPS"] = round(fps, 2)
return obs, reward, terminated, truncated, info
def _step_mujoco_simulation(self, n_frames: int, info: InfoType):
"""Sets the mujoco simulation. Will step the simulation `n_frames` times, each
time checking if the agent has contacts."""
# Initially set has_contacts to False for all agents
for name in self._agents:
info[name]["has_contacts"] = False
# Check contacts at _every_ step.
# NOTE: Doesn't process whether hits are terminal or not
for _ in range(n_frames):
mj.mj_step(self._spec.model, self._spec.data)
# Check for contacts. We won't break here, but we'll store whether an
# agent has contacts or not. If we didn't store during the simulation
# step, contact checking would only occur after the frame skip, meaning
# that, if during the course of the frame skip, the agent hits an object
# and then moves away, the contact would not be detected.
if self._spec.data.ncon > 0:
for name, agent in self._agents.items():
if not info[name]["has_contacts"]:
# Only check for has contacts if it hasn't been set to True
# This reduces redundant checks
info[name]["has_contacts"] = agent.has_contacts
def _compute_terminated(self, info: InfoType) -> TerminatedType:
"""Compute whether the env has terminated. Termination indicates success,
whereas truncated indicates failure.
The default implementation will always return False for all agents. This can
be overridden in subclasses to provide custom termination conditions.
"""
terminated: Dict[str, bool] = {}
for name, agent in self._agents.items():
terminated[name] = self._config.termination_fn(self, agent, info[name])
return terminated
def _compute_truncated(self, info: InfoType) -> TruncatedType:
"""Compute whether the env has terminated. Termination indicates success,
whereas truncated indicates failure.
The default implementation will always return False for all agents. This can
be overridden in subclasses to provide custom termination conditions.
"""
truncated: Dict[str, bool] = {}
for name, agent in self._agents.items():
truncated[name] = self._config.truncation_fn(self, agent, info[name])
return truncated
def _compute_reward(
self,
terminated: TerminatedType,
truncated: TruncatedType,
info: InfoType,
) -> RewardType:
"""Computes the reward for the environment.
Args:
terminated (TerminatedType): Whether each agent has terminated.
Termination indicates success (agent has reached the goal).
truncated (TruncatedType): Whether each agent has truncated.
Truncation indicates failure (agent has hit the wall or something).
info (InfoType): The info dict for each agent.
"""
rewards: Dict[str, float] = {}
for name, agent in self._agents.items():
rewards[name] = self._config.reward_fn(
self, agent, terminated[name], truncated[name], info[name]
)
return rewards
def render(self) -> RenderFrame:
"""Renders the environment.
Returns:
RenderFrame: The rendered frame.
"""
assert self._renderer is not None, "Renderer has not been initialized! Ensure a renderer config is set in the environment config."
overlays = []
if self._config.add_overlays:
cursor = MjCambrianCursor(
self._renderer.width,
self._renderer.height,
position=MjCambrianCursor.Position.TOP_LEFT,
)
for key, value in self._overlays.items():
if issubclass(type(value), MjCambrianViewerOverlay):
overlay = value
else:
overlay = MjCambrianViewerOverlay.create_text_overlay(
f"{key}: {value}"
)
cursor = overlay.place(cursor)
overlays.append(overlay)
if self._config.debug_overlays_size > 0:
overlays.extend(self._generate_overlays())
return self._renderer.render(overlays=overlays)
def _generate_overlays(self) -> List[MjCambrianViewerOverlay]:
overlays: List[MjCambrianViewerOverlay] = []
renderer_width = self._renderer.width
renderer_height = self._renderer.height
trainable_agents = {n: a for n, a in self._agents.items() if a.trainable}
num_agents = len(trainable_agents)
overlay_width = int(renderer_width // num_agents) if num_agents > 0 else 0
overlay_height = int(renderer_height * self._config.debug_overlays_size)
cursor = MjCambrianCursor(
overlay_width, overlay_height, position=MjCambrianCursor.Position.TOP_LEFT
)
for agent in self._agents.values():
agent_overlays = agent.render()
for overlay in agent_overlays:
cursor = overlay.place(cursor)
overlays.append(overlay)
return overlays
@property
def name(self) -> str:
"""Returns the name of the environment."""
return self._name
@property
def xml(self) -> MjCambrianXML:
"""Returns the xml for the environment."""
return self._xml
@property
def agents(self) -> Dict[str, MjCambrianAgent]:
"""Returns the agents in the environment."""
return self._agents
@property
def renderer(self) -> MjCambrianRenderer:
"""Returns the renderer for the environment."""
return self._renderer
@property
def spec(self) -> MjCambrianSpec:
"""Returns the mujoco spec for the environment."""
return self._spec
@property
def model(self) -> mj.MjModel:
"""Returns the mujoco model for the environment."""
return self._spec.model
@property
def data(self) -> mj.MjData:
"""Returns the mujoco data for the environment."""
return self._spec.data
@property
def episode_step(self) -> int:
"""Returns the current episode step."""
return self._episode_step
@property
def num_timesteps(self) -> int:
"""Returns the number of timesteps."""
return self._num_timesteps
@property
def max_episode_steps(self) -> int:
"""Returns the max episode steps."""
return self._max_episode_steps
@property
def overlays(self) -> Dict[str, Any]:
"""Returns the overlays."""
return self._overlays
@property
def cumulative_reward(self) -> float:
"""Returns the cumulative reward."""
return self._cumulative_reward
@property
def stashed_cumulative_reward(self) -> float:
"""Returns the previous cumulative reward."""
return self._stashed_cumulative_reward
@property
def num_agents(self) -> int:
"""Returns the number of agents in the environment.
This is part of the PettingZoo API.
"""
return len(self.agents)
@property
def possible_agents(self) -> List[str]:
"""Returns the possible agents in the environment.
This is part of the PettingZoo API.
Assumes that the possible agents are the same as the agents.
"""
return list(self._agents.keys())
@property
def observation_spaces(self) -> spaces.Dict:
"""Creates the observation spaces.
This is part of the PettingZoo API.
By default, this environment will support multi-agent
observations/actions/etc. This method will create _all_ the observation
spaces for the environment. But note that stable baselines3 only supports single
agent environments (i.e. non-nested spaces.Dict), so ensure you wrap this env
with a `wrappers.MjCambrianSingleagentEnvWrapper` if you want to use stable
baselines3.
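Example: a small sketch of accessing the spaces ("agent0" is a hypothetical
agent name)::

    all_obs_spaces = env.observation_spaces       # gymnasium spaces.Dict
    agent_obs_space = env.observation_space("agent0")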
"""
# Create the observation_spaces
observation_spaces: Dict[str, spaces.Space] = {}
for name, agent in self._agents.items():
if agent.trainable:
observation_spaces[name] = agent.observation_space
return spaces.Dict(observation_spaces)
@property
def action_spaces(self) -> spaces.Dict:
"""Creates the action spaces.
This is part of the PettingZoo API.
By default, this environment will support multi-agent
observations/actions/etc. This method will create _all_ the action
spaces for the environment. But note that stable baselines3 only supports single
agent environments (i.e. non-nested spaces.Dict), so ensure you wrap this env
with a `wrappers.MjCambrianSingleagentEnvWrapper` if you want to use stable
baselines3.
"""
# Create the action_spaces
action_spaces: Dict[str, spaces.Space] = {}
for name, agent in self._agents.items():
if agent.trainable:
action_spaces[name] = agent.action_space
return spaces.Dict(action_spaces)
def observation_space(self, agent: str) -> spaces.Space:
"""Returns the observation space for the given agent.
This is part of the PettingZoo API.
"""
assert agent in list(
self.observation_spaces.keys()
), f"Agent {agent} not found. Available: {list(self.observation_spaces.keys())}"
return self.observation_spaces[agent]
def action_space(self, agent: str) -> spaces.Space:
"""Returns the action space for the given agent.
This is part of the PettingZoo API.
"""
assert agent in list(
self.action_spaces.keys()
), f"Agent {agent} not found. Available: {list(self.action_spaces.keys())}"
return self.action_spaces[agent]
def state(self) -> np.ndarray:
"""Returns the state of the environment.
This is part of the PettingZoo API.
"""
raise NotImplementedError("Not implemented yet.")
def set_random_seed(self, seed: int | float | None):
"""Sets the seed for the environment."""
from stable_baselines3.common.utils import set_random_seed
if seed is None:
return
get_logger().info(f"Setting random seed to {seed}")
set_random_seed(seed)
def record(self, record: bool = True, *, path: Optional[Path] = None):
"""Sets whether the environment is recording."""
self._record = record
self._renderer.record(record, path=path)
if not self._record:
self._rollout.clear()
def save(self, path: str | Path, *, save_pkl: bool = False, **kwargs):
"""Saves the simulation output to the given path."""
path = Path(path)
self._renderer.save(path, **kwargs)
if save_pkl:
get_logger().info(f"Saving rollout to {path.with_suffix('.pkl')}")
with open(path.with_suffix(".pkl"), "wb") as f:
pickle.dump(self._rollout, f)
get_logger().debug(f"Saved rollout to {path.with_suffix('.pkl')}")
def close(self):
"""Closes the environment."""
pass
if __name__ == "__main__":
import argparse
from cambrian import MjCambrianConfig, run_hydra
REGISTRY = {}
def register_fn(fn: Callable):
REGISTRY[fn.__name__] = fn
return fn
@register_fn
def run_mj_viewer(config: MjCambrianConfig, **__):
import mujoco.viewer
env = config.env.instance(config.env)
env.reset(seed=config.seed)
with mujoco.viewer.launch_passive(env.model, env.data) as viewer:
while viewer.is_running():
# env.step(env.action_spaces.sample())
viewer.sync()
@register_fn
def run_renderer(config: MjCambrianConfig, *, record: bool, no_step: bool, **__):
config.save(config.expdir / "config.yaml")
env = config.env.instance(config.eval_env)
env.record(record, path=config.expdir)
env.reset(seed=config.seed)
env.spec.save(config.expdir / "env.xml")
action = {name: [-1.0, -0.0] for name, a in env.agents.items() if a.trainable}
env.step(action.copy())
if "human" in config.env.renderer.render_modes:
import glfw
def custom_key_callback(_, key, *args, **__):
if key == glfw.KEY_R:
env.reset()
elif key == glfw.KEY_UP:
name = next(iter(action.keys()))
action[name][0] += 0.001
action[name][0] = min(1.0, action[name][0])
elif key == glfw.KEY_DOWN:
name = next(iter(action.keys()))
action[name][0] -= 0.001
action[name][0] = max(-1.0, action[name][0])
elif key == glfw.KEY_LEFT:
name = next(iter(action.keys()))
action[name][1] -= 0.01
action[name][1] = max(-1.0, action[name][1])
elif key == glfw.KEY_RIGHT:
name = next(iter(action.keys()))
action[name][1] += 0.01
action[name][1] = min(1.0, action[name][1])
elif key == glfw.KEY_S:
get_logger().info(f"Saving env to {config.expdir / 'env.xml'}")
mj.mj_saveLastXML(str(config.expdir / "env.xml"), env.model)
env.renderer.viewer.custom_key_callback = custom_key_callback
while env.renderer.is_running():
if env.episode_step > env.max_episode_steps:
break
if not no_step:
env.step(action.copy())
env.render()
if record:
env.save(
config.expdir / "eval",
save_pkl=False,
save_mode=MjCambrianRendererSaveMode.MP4
| MjCambrianRendererSaveMode.GIF
| MjCambrianRendererSaveMode.PNG,
)
def main(config: MjCambrianConfig, *, fn: str, **kwargs):
if fn not in REGISTRY:
raise ValueError(f"Unknown function {fn}")
REGISTRY[fn](config, **kwargs)
parser = argparse.ArgumentParser()
parser.add_argument(
"fn", type=str, help="The method to run.", choices=REGISTRY.keys()
)
parser.add_argument(
"--record",
action="store_true",
help="Record the simulation.",
default=False,
)
parser.add_argument(
"--no-step",
action="store_true",
help="Don't step the environment. Useful for debugging.",
default=False,
)
run_hydra(main, parser=parser)