Source code for cambrian.agents.agent

"""Defines agent classes."""

from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Self, Tuple

import mujoco as mj
import numpy as np
from gymnasium import spaces
from hydra_config import HydraContainerConfig, config_wrapper

from cambrian.eyes.eye import MjCambrianEye, MjCambrianEyeConfig
from cambrian.renderer.render_utils import generate_composite, resize_with_aspect_fill
from cambrian.utils import (
    MjCambrianActuator,
    MjCambrianGeometry,
    MjCambrianJoint,
    get_logger,
)
from cambrian.utils.cambrian_xml import MjCambrianXML, MjCambrianXMLConfig
from cambrian.utils.spec import MjCambrianSpec, spec_from_xml_string
from cambrian.utils.types import ActionType, ObsType, RenderFrame

if TYPE_CHECKING:
    from cambrian.envs import MjCambrianEnv


@config_wrapper
[docs] class MjCambrianAgentConfig(HydraContainerConfig): """Defines the config for an agent. Used for type hinting. Attributes: instance (Callable[[Self, str, int], "MjCambrianAgent"]): The class instance for the agent. This is used to create the agent. Takes the config, the name, and the index of the agent as arguments. trainable (bool): Whether the agent is trainable or not. If the agent is trainable, it's observations will be included in the observation space of the environment and the model's output actions will be applied to the agent. If the agent is not trainable, the agent's policy can be defined by overriding the `get_action_privileged` method. use_privileged_action (bool): This is similar to `trainable`, but the agent's action and observation spaces is included in the environment's action and observation spaces, respectively. This is useful for agents that are trainable, but have some special logic that needs to be implemented in the `get_action_privileged` method. `trainable` takes precedence over this attribute, as in if `trainable` is False, this attribute is ignored. overlay_color (Tuple[float, float, float, float]): The color to use in the visualization of the agent. overlay_size (float): The size of the overlay in the visualization of the agent. xml (MjCambrianXMLConfig): The xml for the agent. This is the xml that will be used to create the agent. You should use ${parent:xml} to generate named attributes. This will search upwards in the yaml file to find the name of the agent. body_name (str): The name of the body that defines the main body of the agent. geom_name (str): The name of the geom that are used for eye placement. check_contacts (bool): Whether to check contacts or not. If this is True, then the contacts will be checked in the environment. init_pos (Tuple[float | None]): The initial position of the agent. Specific indices of the position are set when not None. The length of the tuple should be <= 3. None's are filled in at the end if the length is less than 3. init_quat (Tuple[float | None]): The initial quaternion of the agent. Specific indices of the quaternion are set when not None. The length of the tuple should be <= 4. None's are filled in at the end if the length is less than 4. perturb_init_pos (bool): Whether to perturb the initial pos of the agent or not. If this is True, then the initial pos of the agent will be randomly adjusted based on a normal distribution. use_action_obs (bool): Whether to use the action observation or not. NOTE: If the MjCambrianConstantActionWrapper is used, this is not reflected in the observation, as in the actions will vary in the observation. use_contact_obs (bool): Whether to use the contact observation or not. If this is True, then the contacts will be included in the observation space of the agent. eyes (Dict[str, MjCambrianEyeConfig]): The eyes on the agent. The keys are the names of the eyes and the values are the configs for the eyes. The eyes will be placed on the agent at the specified coordinates. """ instance: Callable[[Self, str, int], "MjCambrianAgent"] trainable: bool use_privileged_action: bool overlay_color: Tuple[float, float, float, float] overlay_size: float xml: MjCambrianXMLConfig body_name: str geom_name: str check_contacts: bool init_pos: Tuple[float | None] init_quat: Tuple[float | None] perturb_init_pos: bool use_action_obs: bool use_contact_obs: bool eyes: Dict[str, MjCambrianEyeConfig]
[docs] class MjCambrianAgent: """The agent class is defined as a physics object with eyes. This object serves as an agent in a multi-agent mujoco environment. Therefore, it must have a uniquely identifiable name. In our context, an agent has at any number of eyes and a body which an eye can be attached to. This class abstracts away the inner workings of the mujoco model itself to the xml generation/loading. It uses existing xml files that only define the agents, with which are going to be accumulating into one large xml file that will be loaded into mujoco. To support specific agent types, you should define subclasses that include agent specific configs (i.e. model_path, num_joints). Args: config (MjCambrianAgentConfig): The configuration for the agent. name (str): The name of the agent. This is used to identify the agent in the environment. """ def __init__(self, config: MjCambrianAgentConfig, name: str): self._config = config self._name = name self._eyes: Dict[str, MjCambrianEye] = {} self._spec: MjCambrianSpec = None self._init_pos: Tuple[float | None, float | None, float | None] = None self._init_quat: Tuple[ float | None, float | None, float | None, float | None ] = None self._last_action: List[float] = None self._actuators: List[MjCambrianActuator] = [] self._joints: List[MjCambrianJoint] = [] self._geom: MjCambrianGeometry = None self._body: mj.MjsBody = None self._actadrs: List[int] = [] self._body_id: int = None self._initialize() def _initialize(self): """Initialize the agent. This method does the following: - load the base xml to MjModel - parse the geometry - place eyes at the appropriate locations """ try: spec = spec_from_xml_string(self._config.xml) except Exception: get_logger().error(f"Error creating model\n{self._config.xml}") raise self._parse_geometry(spec) self._parse_actuators(spec) self._create_eyes() assert len(self._config.init_pos) == 3, "init_pos must have 3 elements." self._init_pos = self._config.init_pos assert len(self._config.init_quat) == 4, "init_quat must have 4 elements." self._init_quat = self._config.init_quat def _parse_geometry(self, spec: MjCambrianSpec): """Parse the geometry to get the root body, number of controls, joints, and actuators. We're going to do some preprocessing of the model here to get info regarding num joints, num controls, etc. This is because we need to know this to compute the observation and action spaces, which is needed _before_ actually initializing mujoco (but we can't get this information until after mujoco is initialized and we don't want to hardcode this for extensibility). Note: We can't grab the ids/adrs here because they'll be different once we load the entire model """ model = spec.model # Num of controls if self.trainable: assert model.nu > 0, "Trainable agents must have controllable actuators." # Get number of qpos/qvel/ctrl # Just stored for later, like to get the observation space, etc. self._numqpos = model.nq self._numctrl = model.nu # Create the geometries we will use for eye placement geom_id = spec.get_geom_id(self._config.geom_name) assert geom_id != -1, f"Could not find geom {self._config.geom_name}." geom_rbound = model.geom_rbound[geom_id] geom_pos = model.geom_pos[geom_id] self._geom = MjCambrianGeometry(geom_id, geom_rbound, geom_pos) def _parse_actuators(self, spec: MjCambrianSpec): """Parse the current model/xml for the actuators. We have to do this twice: once on the initial model load to get the ctrl limits on the actuators (for the obs space). And then later to acquire the actual ids/adrs. """ model = spec.model # Root body for the agent body_name = self._config.body_name body_id = spec.get_body_id(body_name) assert body_id != -1, f"Could not find body with name {body_name}." # Mujoco doesn't have a neat way to grab the actuators associated with a # specific agent/body, so we'll try to grab them dynamically by checking the # transmission ids (the joint/site adrs associated with that actuator) and # seeing if that the corresponding root body is on this agent's body. self._actuators: List[MjCambrianActuator] = [] for actadr, ((trnid, _), trntype) in enumerate( zip(model.actuator_trnid, model.actuator_trntype) ): # Grab the body id associated with the actuator if ( trntype == mj.mjtTrn.mjTRN_JOINT or trntype == mj.mjtTrn.mjTRN_JOINTINPARENT ): act_bodyid = model.jnt_bodyid[trnid] elif trntype == mj.mjtTrn.mjTRN_SITE: act_bodyid = model.site_bodyid[trnid] elif trntype == mj.mjtTrn.mjTRN_BODY: act_bodyid = trnid elif trntype == mj.mjtTrn.mjTRN_TENDON: act_bodyid = model.tendon_adr[trnid] else: raise NotImplementedError(f'Unsupported trntype "{trntype}".') # Add the actuator if the rootbody of the actuator is the same body # as the transimission's body act_rootbodyid = model.body_rootid[act_bodyid] if act_rootbodyid == body_id: ctrlrange = model.actuator_ctrlrange[actadr] ctrllimited = model.actuator_ctrllimited[actadr] self._actuators.append( MjCambrianActuator(actadr, trnid, ctrlrange, ctrllimited) ) # Get the joints # We use the joints to get the qpos/qvel as observations (joint specific states) self._joints: List[MjCambrianJoint] = [] for jntadr, jnt_bodyid in enumerate(model.jnt_bodyid): jnt_rootbodyid = model.body_rootid[jnt_bodyid] if jnt_rootbodyid == body_id: # This joint is associated with this agent's body self._joints.append(MjCambrianJoint.create(model, jntadr)) assert ( len(self._joints) > 0 ), f"Body {body_name} has no joints. Joints are required for positioning." if self.trainable: assert len(self._actuators) > 0, f"Body {body_name} has no actuators." def _create_eyes(self): """Place the eyes on the agent.""" for name, eye_config in self._config.eyes.items(): self._eyes[name] = eye_config.instance(eye_config, f"{self._name}_{name}")
[docs] def generate_xml(self) -> MjCambrianXML: """Generates the xml for the agent. Will generate the xml from the model file and then add eyes to it. """ xml = MjCambrianXML.from_string(self._config.xml) # Add eyes for eye in self.eyes.values(): xml += eye.generate_xml(xml, self._geom, self._config.body_name) return xml
[docs] def apply_action(self, actions: ActionType): """Applies the action to the agent. This probably happens before step so that the observations reflect the state of the agent after the new action is applied. It is assumed that the actions are normalized between -1 and 1. """ self._last_action = actions.copy() if len(actions) == 0: return for action, actuator in zip(actions, self._actuators): if actuator.ctrllimited: action = np.interp(action, [-1, 1], actuator.ctrlrange) self._spec.data.ctrl[actuator.adr] = action
[docs] def get_action_privileged(self, env: "MjCambrianEnv") -> List[float]: """This is a deviation from the standard gym API. This method is similar to step, but it has "privileged" access to information such as the environment. This method can be overridden by agents which are not trainable and need to implement custom step logic. Args: env (MjCambrianEnv): The environment that the agent is in. This can be used to get information about the environment. Returns: List[float]: The action to take. """ raise NotImplementedError( "This method should be overridden by the subclass " "and should never reach here." )
[docs] def reset(self, spec: MjCambrianSpec) -> ObsType: """Sets up the agent in the environment. Uses the model/data to update positions during the simulation. """ self._spec = spec # Parse actuators; this is the second time we're doing this, but we need to # get the adrs of the actuators in the current model self._parse_actuators(spec) # Accumulate the qpos/qvel/act adrs self._reset_adrs(spec) # Reset the pose of the agent self._reset_pose(spec) # Set the last action to whatever the last ctrl was self._last_action = [] for actuator in self._actuators: action = self._spec.data.ctrl[actuator.adr] if actuator.ctrllimited: action = np.interp(action, [-1, 1], actuator.ctrlrange) self._last_action.append(action) obs: Dict[str, Any] = {} for name, eye in self.eyes.items(): eye_obs = eye.reset(spec) if isinstance(eye_obs, dict): obs.update(eye_obs) else: obs[name] = eye_obs return self._update_obs(obs)
def _reset_adrs(self, spec: MjCambrianSpec): """Resets the adrs for the agent. This is used when the model is reloaded. .. todo:: This was before the switch to spec. Is this still necessary? """ # Root body for the agent body_name = self._config.body_name self._body_id = spec.get_body_id(body_name) assert self._body_id != -1, f"Could not find body with name {body_name}." # Geometry id geom_id = spec.get_geom_id(self._config.geom_name) assert geom_id != -1, f"Could not find geom {self._config.geom_name}." self._geom.id = geom_id # Accumulate the qposadrs self._qposadrs = [] for joint in self._joints: self._qposadrs.extend(joint.qposadrs) self._actadrs: List[int] = [act.adr for act in self._actuators] assert ( len(self._qposadrs) == self._numqpos ), f"Mismatch in qpos adrs for agent '{self.name}': " f"{len(self._qposadrs)} != {self._numqpos}." assert ( len(self._actadrs) == self._numctrl ), f"Mismatch in actuator adrs for agent '{self.name}': " f"{len(self._actadrs)} != {self._numctrl}." def _reset_pose(self, spec: MjCambrianSpec): """Resets the pose of the agent.""" self.pos = self._init_pos self.quat = self._init_quat # step here so that the states are updated mj.mj_forward(spec.model, spec.data) if self._config.perturb_init_pos: pos = np.random.normal(0, self._geom.rbound / 2, 2) pos = np.clip(pos, -self._geom.rbound / 2, self._geom.rbound / 2) self.pos += [*pos, 0]
[docs] def step(self) -> ObsType: """Steps the eyes and returns the observation.""" obs: ObsType = {} for name, eye in self.eyes.items(): eye_obs = eye.step() if isinstance(eye_obs, dict): obs.update(eye_obs) else: obs[name] = eye_obs return self._update_obs(obs)
def _update_obs(self, obs: ObsType) -> ObsType: """Add additional attributes to the observation.""" if self._config.use_action_obs: obs["action"] = self._last_action if self._config.use_contact_obs: obs["contacts"] = [self.has_contacts] return obs
[docs] def render(self) -> RenderFrame: """Renders the eyes and returns the debug image. We don't know where the eyes are placed, so for simplicity, we'll just return the first eye's render. """ if len(self._eyes) == 0: return None if len(self._eyes) == 1: return next(iter(self._eyes.values())).render() # Stack the renders next to each other renders = [eye.render() for eye in self._eyes.values()] max_width, max_height = max(render.shape[1] for render in renders), max( render.shape[0] for render in renders ) renders = [ resize_with_aspect_fill(render, max_height, max_width) for render in renders ] return generate_composite({0: {i: render for i, render in enumerate(renders)}})
@property
[docs] def has_contacts(self) -> bool: """Returns whether or not the agent has contacts. Walks through all the contacts in the environment and checks if any of them involve this agent. """ if not self._config.check_contacts: return False for contact in self._spec.data.contact: if contact.exclude: continue geom1 = int(contact.geom[0]) body1 = self._spec.model.geom_bodyid[geom1] rootbody1 = self._spec.model.body_rootid[body1] geom2 = int(contact.geom[1]) body2 = self._spec.model.geom_bodyid[geom2] rootbody2 = self._spec.model.body_rootid[body2] is_this_agent = rootbody1 == self._body_id or rootbody2 == self._body_id if not is_this_agent or contact.exclude: # Not a contact with this agent continue return True return False
@cached_property
[docs] def observation_space(self) -> spaces.Space: """The observation space is defined on an agent basis. the `env` should combine the observation spaces such that it's supported by stable_baselines3/pettingzoo. The agent has three observation spaces: - {eye.name}: The eyes observations - action: The last action that was applied to the agent - contacts: Whether the agent has contacts or not """ observation_space: Dict[Any, spaces.Space] = {} for name, eye in self.eyes.items(): if isinstance(eye.observation_space, spaces.Dict): observation_space.update(eye.observation_space.spaces) else: observation_space[name] = eye.observation_space if self._config.use_action_obs: observation_space["action"] = self.action_space if self._config.use_contact_obs: observation_space["contacts"] = spaces.Box( low=0, high=1, shape=(1,), dtype=np.int32 ) return spaces.Dict(observation_space)
@cached_property
[docs] def action_space(self) -> spaces.Space: """The action space is simply the controllable actuators of the agent.""" return spaces.Box(low=-1, high=1, shape=(self._numctrl,), dtype=np.float32)
@property def config(self) -> MjCambrianAgentConfig: return self._config @property def name(self) -> str: return self._name @property def eyes(self) -> Dict[str, MjCambrianEye]: return self._eyes @property def last_action(self) -> List[float]: return self._last_action @property
[docs] def qpos(self) -> np.ndarray: """Gets the qpos of the agent. The qpos is the state of the joints defined in the agent's xml. This method is used to get the state of the qpos. It's actually a masked array where the entries are masked if the qpos adr is not associated with the agent. This allows the return value to be indexed and edited as if it were the full qpos array.""" mask = np.ones(self._spec.data.qpos.shape, dtype=bool) mask[self._qposadrs] = False masked_qpos = np.ma.masked_array(self._spec.data.qpos, mask=mask) return masked_qpos[np.flatnonzero(~masked_qpos.mask)]
@qpos.setter def qpos(self, value: np.ndarray[float | None]): """Set's the qpos of the agent. The qpos is the state of the joints defined in the agent's xml. This method is used to set the state of the qpos. The value input is a numpy array where the entries are either values to set to the corresponding qpos adr or None. It's allowed for `value` to be less than the total number of joints in the agent. If this is the case, only the first `len(value)` joints will be updated. """ for idx, val in enumerate(value): if val is not None: self._spec.data.qpos[self._qposadrs[idx]] = val @property
[docs] def pos(self) -> np.ndarray: """Returns the position of the agent in the environment. Note: the returned value, if edited, doesn't not directly impact the simulation. To set the position of the agent, use the `pos` setter. """ return self._spec.data.xpos[self._body_id].copy()
@pos.setter def pos(self, value: Tuple[float | None, float | None, float | None]): """Sets the position of the agent in the environment. The value is a tuple of the x, y, and z positions. If the value is None, the position is not updated. Note: This base implementation assumes the first 3 values of the qpos are the x, y, and z positions of the agent. This may not be the case and depends on the joints defined in the agent, so this method should be overridden in the subclass if this is not the case. """ for idx, val in enumerate(value): if val is not None: self._spec.data.qpos[self._qposadrs[idx]] = val @property
[docs] def init_pos(self) -> Tuple[float | None, float | None, float | None]: """Returns the initial position of the agent in the environment.""" assert self.body is not None, "Body is not set." return self.body.pos
@init_pos.setter def init_pos(self, value: Tuple[float | None, float | None, float | None]): """Sets the initial position of the agent in the environment. The value is a tuple of the x, y, and z positions. If the value is None, the position is not updated.""" init_pos = [None, None, None] for i in range(3): if self._config.init_pos[i] is not None: init_pos[i] = self._config.init_pos[i] elif i < len(value) and value[i] is not None: init_pos[i] = value[i] else: init_pos[i] = self._init_pos[i] self._init_pos = init_pos @property
[docs] def quat(self) -> np.ndarray: """Returns the quaternion of the agent in the environment. Fmt: `wxyz`. Note: The returned value, if edited, doesn't not directly impact the simulation. To set the quaternion of the agent, use the `quat` setter. """ return self._spec.data.xquat[self._body_id].copy()
@quat.setter def quat( self, value: Tuple[float | None, float | None, float | None, float | None] ): """Sets the quaternion of the agent in the environment. The value is a tuple of the x, y, z, and w values. If the value is None, the quaternion is not updated. Note: This base implementation assumes the 3,4,5,6 indices of the qpos are the x, y, z, and w values of the quaternion of the agent. This may not be the case and depends on the joints defined in the agent, so this method should be overridden in the subclass if this is not the case. """ for idx, val in enumerate(value): if val is not None: self._spec.data.qpos[self._qposadrs[3 + idx]] = val @property
[docs] def trainable(self) -> bool: """Returns whether the agent is trainable or not.""" return self._config.trainable
@property
[docs] def num_eyes(self) -> int: """Returns the number of eyes on the agent.""" num_eyes = 0 for eye in self.eyes.values(): num_eyes += getattr(eye, "num_eyes", 1) return num_eyes
[docs] class MjCambrianAgent2D(MjCambrianAgent): """Assumes the agent is moving on 2D plane and has a yaw hinge joint which is used to adjust orientation of the agent.""" @MjCambrianAgent.quat.setter def quat( self, value: Tuple[float | None, float | None, float | None, float | None] ): """Overrides the base implementation to set the z rotation.""" assert len(value) == 4, f"Quaternion must have 4 elements, got {len(value)}." # Only set quat if all values are not None if any(val is None for val in value): return self.qpos[self._qposadrs[2]] = np.arctan2( 2 * (value[0] * value[3] + value[1] * value[2]), 1 - 2 * (value[2] ** 2 + value[3] ** 2), )