# Source code for cambrian.envs.step_fns

"""Step fns. These can be used to modify the observation and info dictionaries."""

from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from cambrian.agents import MjCambrianAgent
from cambrian.envs import MjCambrianEnv, MjCambrianMazeEnv

# ======================
# Helpers


def respawn_agent(
    env: MjCambrianMazeEnv,
    agent: MjCambrianAgent,
) -> Dict[str, Any]:
    """Move ``agent`` to a freshly generated reset position and reset it.

    The new initial position is requested from
    ``env.maze.generate_reset_pos`` using the agent's name and stored on
    ``agent.init_pos`` *before* the agent is reset, so the reset sees the
    updated position.

    Args:
        env: Environment providing ``maze`` and ``spec``.
        agent: Agent to respawn; its ``init_pos`` attribute is overwritten.

    Returns:
        Whatever ``agent.reset(env.spec)`` returns (annotated as the agent's
        observation dictionary).
    """
    new_position = env.maze.generate_reset_pos(agent.name)
    agent.init_pos = new_position
    return agent.reset(env.spec)
# ======================
# Step Functions
def step_respawn_agents_if_close_to_agents(
    env: MjCambrianMazeEnv,
    obs: Dict[str, Any],
    info: Dict[str, Dict[str, Any]],
    *,
    distance_threshold: float,
    for_agents: Optional[List[str]] = None,
    to_agents: Optional[List[str]] = None,
    from_agents: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Respawn agents that are closer than a threshold to another agent.

    For every selected agent, the Euclidean distance (``np.linalg.norm`` of
    the difference of the two ``pos`` attributes) to every selected other
    agent is compared against ``distance_threshold``. On a hit the agent is
    respawned via ``respawn_agent`` and its entry in ``obs`` is replaced by
    the observation returned from the respawn.

    ``info[agent_name]["respawned"]`` is written for every selected agent:
    ``True`` if the agent was respawned at least once during this call,
    ``False`` otherwise. ``info`` must therefore already contain an entry
    for each selected agent.

    Args:
        env: Environment whose ``agents`` mapping is iterated.
        obs: Per-agent observations; updated in place for respawned agents.
        info: Per-agent info dictionaries; updated in place.

    Keyword Args:
        distance_threshold: An agent is respawned when its distance to
            another agent is strictly less than this value.
        for_agents: List of agent names to check for proximity. ``None``
            means all agents.
        to_agents: List of agent names to check distance to. ``None`` means
            all agents.
        from_agents: List of agent names to check distance from. Applied as
            an additional filter on the same agents as ``for_agents``;
            ``None`` means no extra filtering.

    Returns:
        The (mutated) ``obs`` and ``info`` dictionaries.
    """
    for agent_name, agent in env.agents.items():
        if for_agents is not None and agent_name not in for_agents:
            continue
        if from_agents is not None and agent_name not in from_agents:
            continue

        # Initialise the flag once per agent. Resetting it inside the inner
        # loop would let a later, non-colliding comparison overwrite a True
        # that was set by an earlier respawn, and would leave the key unset
        # when there is no other agent to compare against.
        info[agent_name]["respawned"] = False

        for other_agent_name, other_agent in env.agents.items():
            if to_agents is not None and other_agent_name not in to_agents:
                continue
            if agent_name == other_agent_name:
                continue

            if np.linalg.norm(agent.pos - other_agent.pos) < distance_threshold:
                obs[agent_name] = respawn_agent(env, agent)
                info[agent_name]["respawned"] = True

    return obs, info
def step_add_agent_qpos_to_info(
    env: MjCambrianEnv,
    obs: Dict[str, Any],
    info: Dict[str, Dict[str, Any]],
    *,
    for_agents: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Record each selected agent's ``qpos`` in its info dictionary.

    The value of ``agent.qpos`` is stored under ``info[agent_name]["qpos"]``
    exactly as returned by the attribute (no copy is made here). ``obs`` is
    passed through untouched.

    Args:
        env: Environment whose ``agents`` mapping is iterated.
        obs: Per-agent observations; returned unchanged.
        info: Per-agent info dictionaries; updated in place. Must already
            contain an entry for every selected agent.

    Keyword Args:
        for_agents: Names of the agents to record. ``None`` selects every
            agent in ``env.agents``.

    Returns:
        The unchanged ``obs`` and the (mutated) ``info``.
    """
    for name, agent in env.agents.items():
        is_selected = for_agents is None or name in for_agents
        if is_selected:
            info[name]["qpos"] = agent.qpos

    return obs, info
def step_combined(
    env: MjCambrianEnv,
    obs: Dict[str, Any],
    info: Dict[str, Dict[str, Any]],
    **step_fns,
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Chain several step functions into one.

    Each callable passed as a keyword argument is invoked as
    ``fn(env, obs, info)`` and must return an ``(obs, info)`` pair, which is
    fed into the next callable. Callables run in the order the keyword
    arguments were given; the keyword names themselves are not used, and no
    additional keyword arguments are forwarded to the callables.

    Args:
        env: Environment handed to every step function.
        obs: Initial observation dictionary.
        info: Initial info dictionary.
        **step_fns: Step functions to apply, keyed by an arbitrary name.

    Returns:
        The ``(obs, info)`` pair produced by the last step function, or the
        inputs unchanged when no step functions are given.
    """
    current_obs, current_info = obs, info
    for fn in step_fns.values():
        current_obs, current_info = fn(env, current_obs, current_info)
    return current_obs, current_info