import contextlib
import pickle
from dataclasses import dataclass
from fnmatch import fnmatch
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
)
import mujoco as mj
import numpy as np
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecEnv
from cambrian.utils.logger import get_logger
if TYPE_CHECKING:
from cambrian.agents.agent import MjCambrianAgent
from cambrian.envs.env import MjCambrianEnv
from cambrian.ml.model import MjCambrianModel
# ============
# Module-level torch device, resolved once at import time through
# Stable-Baselines3's `get_device`.
# NOTE(review): "auto" presumably selects CUDA when available and falls back to
# CPU otherwise -- confirm against the SB3 documentation.
device = get_device("auto")
# ============
[docs]
def evaluate_policy(
    env: VecEnv,
    model: "MjCambrianModel",
    num_runs: int,
    *,
    record_kwargs: Optional[Dict[str, Any]] = None,
    step_callback: Optional[Callable[["MjCambrianEnv"], bool | None]] = lambda _: True,
    done_callback: Optional[Callable[[int], bool | None]] = lambda _: True,
) -> float:
    """Evaluate a policy.

    Args:
        env (VecEnv): The environment to evaluate the policy on. Assumed to be a
            VecEnv wrapper around a MjCambrianEnv; only ``env.envs[0]`` is
            inspected.
        model (MjCambrianModel): The model to evaluate.
        num_runs (int): The number of runs to evaluate the policy on.

    Keyword Args:
        record_kwargs (Dict[str, Any]): The keyword arguments to pass to the save
            method of the environment. If None, the environment will not be
            recorded.
        step_callback (Callable[[MjCambrianEnv], bool | None]): The callback
            function to call at each step with the underlying environment. If the
            function returns False, the evaluation will stop. If None, no
            callback is invoked.
        done_callback (Callable[[int], bool | None]): The callback function to
            call with the run index when a run is done. If the function returns
            False, the evaluation will stop. If None, no callback is invoked.

    Returns:
        float: The ``stashed_cumulative_reward`` of the underlying environment at
        the end of the evaluation.
    """
    # To avoid circular imports
    from cambrian.envs import MjCambrianEnv
    from cambrian.utils.logger import get_logger

    cambrian_env: MjCambrianEnv = env.envs[0].unwrapped
    recording = record_kwargs is not None
    if recording:
        # don't set to `record_path is not None` directly bc this will delete overlays
        cambrian_env.record()

    run = 0
    obs = env.reset()
    get_logger().info(f"Starting {num_runs} evaluation run(s)...")
    while run < num_runs:
        action, _ = model.predict(obs, deterministic=True)
        obs, _, done, _ = env.step(action)

        # NOTE(review): `done` comes from a VecEnv step; the truth test below
        # presumably relies on a single wrapped environment -- confirm.
        if done:
            get_logger().info(
                f"Run {run} done. "
                f"Cumulative reward: {cambrian_env.stashed_cumulative_reward}"
            )
            # The callbacks are typed Optional, so None means "no callback"
            # rather than a TypeError on the first call.
            if done_callback is not None and done_callback(run) is False:
                break
            run += 1

        # Only an explicit False stops the evaluation; a None return continues.
        if step_callback is not None and step_callback(cambrian_env) is False:
            break

        if recording:
            env.render()

    if recording:
        cambrian_env.save(**record_kwargs)
        cambrian_env.record(False)

    return cambrian_env.stashed_cumulative_reward
def moving_average(values, window, mode="valid"):
    """Compute the moving average of `values` over `window` samples.

    Args:
        values: The 1D sequence to smooth.
        window: The number of samples in the averaging window.
        mode: The `np.convolve` mode ("valid", "same" or "full").

    Returns:
        The result of convolving `values` with a uniform kernel of size `window`.
    """
    # Uniform kernel: every sample in the window contributes 1 / window.
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode=mode)
# =============
[docs]
def save_data(data: Any, outdir: Path, pickle_file: Path):
    """Pickle `data` to `outdir / pickle_file`.

    Args:
        data (Any): The object to pickle.
        outdir (Path): The directory the pickle file is placed under.
        pickle_file (Path): The path of the pickle file, joined onto `outdir`.
    """
    target = (outdir / pickle_file).resolve()
    # Create any missing parent directories before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("wb") as stream:
        pickle.dump(data, stream)
    get_logger().info(f"Saved parsed data to {target}.")
[docs]
def try_load_pickle(folder: Path, pickle_file: Path) -> Any | None:
    """Load pickled data from `folder / pickle_file` if that file exists.

    Args:
        folder (Path): The directory the pickle file lives under.
        pickle_file (Path): The path of the pickle file, joined onto `folder`.

    Returns:
        Any | None: The unpickled object, or None (after logging a warning) when
        the file does not exist.
    """
    target = (folder / pickle_file).resolve()
    if not target.exists():
        get_logger().warning(f"Could not load {target}.")
        return None

    get_logger().info(f"Loading parsed data from {target}...")
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with target.open("rb") as stream:
        loaded = pickle.load(stream)
    get_logger().info(f"Loaded parsed data from {target}.")
    return loaded
# =============
[docs]
def generate_sequence_from_range(
    range: Tuple[float, float], num: int, endpoint: bool = True
) -> List[float]:
    """Build a list of `num` floats spanning `range`.

    A single-element request yields the average of the range; any other count is
    delegated to `np.linspace`.

    Args:
        range (Tuple[float, float]): The (start, stop) bounds of the sequence.
        num (int): The number of elements in the sequence.

    Keyword Args:
        endpoint (bool): Whether to include the endpoint in the sequence.

    Returns:
        List[float]: The generated values as plain Python floats.
    """
    if num == 1:
        samples = [np.average(range)]
    else:
        start, stop = range
        samples = np.linspace(start, stop, num, endpoint=endpoint)
    # Convert numpy scalars to builtin floats.
    return list(map(float, samples))
[docs]
@contextlib.contextmanager
def setattrs_temporary(
    *args: Tuple[Any, Dict[str, Any]]
) -> Generator[None, None, None]:
    """Temporarily set attributes of objects (or items of dicts).

    Each positional argument is an ``(obj, overrides)`` pair. For a dict ``obj``
    the overrides are applied as items, otherwise as attributes. The previous
    values are restored when the context exits, including when the body raises
    or when applying one of the overrides fails midway.

    Args:
        *args (Tuple[Any, Dict[str, Any]]): Pairs of target object and the
            mapping of names to temporary values.

    Raises:
        KeyError: If a dict target does not already contain an overridden key.
        AttributeError: If an object target does not already have an overridden
            attribute.
    """
    # (target, name, previous value) for each override, in application order.
    applied: List[Tuple[Any, str, Any]] = []
    try:
        for obj, kwargs in args:
            for attr, value in kwargs.items():
                if isinstance(obj, dict):
                    previous = obj[attr]
                    obj[attr] = value
                else:
                    previous = getattr(obj, attr)
                    setattr(obj, attr, value)
                # Record only after the override succeeded so that a failed
                # assignment is not "restored" later.
                applied.append((obj, attr, previous))
        yield
    finally:
        # Undo in reverse order so a target overridden more than once ends up
        # with its original value rather than an intermediate one.
        for obj, attr, previous in reversed(applied):
            if isinstance(obj, dict):
                obj[attr] = previous
            else:
                setattr(obj, attr, previous)
def is_number(maybe_num: Any) -> bool:
    """Return whether `maybe_num` is an instance of `numbers.Number`."""
    import numbers

    return isinstance(maybe_num, numbers.Number)
def is_integer(maybe_int: Any) -> bool:
    """Return whether `maybe_int` represents an integer.

    Args:
        maybe_int (Any): The value to check. Supported kinds are `int`
            instances, strings of digits with an optional leading "-", and numpy
            arrays whose elements all have no fractional part.

    Returns:
        bool: True if the value is integer-like as described above; False for
        everything else, including the empty string.
    """
    if isinstance(maybe_int, int):
        return True
    if isinstance(maybe_int, str):
        # Strip at most one leading minus sign; "" and "-" leave no digits and
        # are therefore rejected instead of raising an IndexError.
        digits = maybe_int[1:] if maybe_int.startswith("-") else maybe_int
        return digits.isdigit()
    if isinstance(maybe_int, np.ndarray):
        # Cast so callers get a builtin bool rather than np.bool_.
        return bool(np.all(np.mod(maybe_int, 1) == 0))
    return False
[docs]
def make_odd(num: int | float) -> int:
"""Make a number odd by adding 1 if it is even. If `num` is a float, it is cast to
an int."""
return int(num) if num % 2 == 1 else int(num) + 1
[docs]
def round_half_up(n: float) -> int:
    """Round `n` to the nearest integer, with exact halves going upward.

    Args:
        n (float): The value to round.

    Returns:
        int: The floor of `n + 0.5`.
    """
    shifted = n + 0.5
    floored = np.floor(shifted)
    return int(floored)
[docs]
def safe_index(
arr: List[Any], value: Any, *, default: Optional[int] = None
) -> int | None:
"""Safely get the index of a value in a list. If the value is not in the list, None
is returned."""
try:
return arr.index(value)
except ValueError:
return default
@contextlib.contextmanager
def suppress_stdout_stderr():
    """Silence `sys.stdout` and `sys.stderr` for the duration of the context.

    Both streams are pointed at `os.devnull` while the body runs and are put back
    afterwards, even if the body raises.
    """
    import os

    with open(os.devnull, "w") as sink:
        # The redirect helpers swap sys.stdout / sys.stderr on entry and restore
        # the previous objects on exit.
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            yield
# =============
# Mujoco utils
[docs]
@dataclass
class MjCambrianActuator:
    """Plain record describing a single Mujoco actuator.

    Attributes:
        adr (int): The Mujoco actuator ID (index into model.actuator_* arrays).
        trnadr (int): The index of the actuator's transmission in the model.
        ctrlrange (Tuple[float, float]): The control range of the actuator.
        ctrllimited (bool): Whether the actuator is control-limited.
    """

    # Actuator index.
    adr: int
    # Transmission index.
    trnadr: int
    # Control range as a pair of floats.
    ctrlrange: Tuple[float, float]
    # Whether the control range is enforced.
    ctrllimited: bool
[docs]
@dataclass
class MjCambrianJoint:
    """Plain record describing a single Mujoco joint.

    Attributes:
        type (int): The Mujoco joint type (mj.mjtJoint).
        adr (int): The Mujoco joint ID (index into model.jnt_* arrays).
        qposadr (int): The index of the joint's position in the qpos array.
        numqpos (int): The number of positions in the joint.
        qveladr (int): The index of the joint's velocity in the qvel array.
        numqvel (int): The number of velocities in the joint.
    """

    type: int
    adr: int
    qposadr: int
    numqpos: int
    qveladr: int
    numqvel: int

    @staticmethod
    def create(model: mj.MjModel, jntadr: int) -> "MjCambrianJoint":
        """Build a joint record for joint `jntadr` of `model`.

        Args:
            model (mj.MjModel): The model whose `jnt_qposadr`, `jnt_dofadr` and
                `jnt_type` arrays are read.
            jntadr (int): The index of the joint in those arrays.

        Returns:
            MjCambrianJoint: The record, with the qpos/qvel sizes chosen from the
            joint type.
        """
        position_start = model.jnt_qposadr[jntadr]
        velocity_start = model.jnt_dofadr[jntadr]
        kind = model.jnt_type[jntadr]

        # Sizes of the joint's qpos / qvel slices depend on the joint type.
        if kind == mj.mjtJoint.mjJNT_FREE:
            numqpos, numqvel = 7, 6
        elif kind == mj.mjtJoint.mjJNT_BALL:
            numqpos, numqvel = 4, 3
        else:  # mj.mjtJoint.mjJNT_HINGE or mj.mjtJoint.mjJNT_SLIDE
            numqpos, numqvel = 1, 1

        return MjCambrianJoint(
            type=kind,
            adr=jntadr,
            qposadr=position_start,
            numqpos=numqpos,
            qveladr=velocity_start,
            numqvel=numqvel,
        )

    @property
    def qposadrs(self) -> List[int]:
        """Get the indices of the joint's positions in the qpos array."""
        first = self.qposadr
        return list(range(first, first + self.numqpos))

    @property
    def qveladrs(self) -> List[int]:
        """Get the indices of the joint's velocities in the qvel array."""
        first = self.qveladr
        return list(range(first, first + self.numqvel))
[docs]
@dataclass
class MjCambrianGeometry:
    """Plain record describing a single Mujoco geometry.

    Attributes:
        id (int): The Mujoco geometry ID (index into model.geom_* arrays).
        rbound (float): The radius of the geometry's bounding sphere.
        pos (np.ndarray): The position of the geometry relative to the body.
    """

    # Geometry index.
    id: int
    # Bounding-sphere radius.
    rbound: float
    # Position vector.
    pos: np.ndarray
[docs]
def pickle_unpickleable_object(source: Any):
    """Reduce `source` to ``(unpickle_callable, (cls, attributes))``.

    The attributes are every name reported by `dir(source)` whose value is not
    callable and whose name does not start with a double underscore.

    Args:
        source (Any): The object to reduce.

    Returns:
        A reduce tuple whose callable is `unpickle_unpickleable_object` and whose
        arguments are the class of `source` and the collected attribute mapping.
    """
    attributes = {}
    for name in dir(source):
        # Skip methods and dunder entries; only plain data attributes are kept.
        if callable(getattr(source, name)) or name.startswith("__"):
            continue
        attributes[name] = getattr(source, name)

    # Return *our* unpickle function + arguments
    return unpickle_unpickleable_object, (source.__class__, attributes)
[docs]
def unpickle_unpickleable_object(cls: Type[Any], attributes: Dict[str, Any]) -> Any:
    """Rebuild an object from its class and a mapping of attribute values.

    Args:
        cls (Type[Any]): The class to instantiate; it is called with no arguments.
        attributes (Dict[str, Any]): The attribute values to set on the new
            instance.

    Returns:
        Any: The new instance with every entry of `attributes` applied.
    """
    # Default-construct, then overwrite the state attribute by attribute.
    restored = cls()
    for name in attributes:
        setattr(restored, name, attributes[name])
    return restored
def register_pickle_unpickleable_object(cls: Type[Any]):
    """Register `pickle_unpickleable_object` as the copyreg reducer for `cls`.

    Args:
        cls (Type[Any]): The class whose instances should be reduced through
            `pickle_unpickleable_object`.
    """
    import copyreg

    def _construct(source):
        # Passed as copyreg's optional constructor argument.
        return unpickle_unpickleable_object(cls, source)

    copyreg.pickle(cls, pickle_unpickleable_object, _construct)
# Register the copyreg reducer for these MuJoCo types at import time, so that
# reducing their instances goes through `pickle_unpickleable_object`.
# NOTE(review): presumably needed because these bindings do not support pickling
# on their own -- confirm against the mujoco Python bindings.
register_pickle_unpickleable_object(mj.MjvOption)
register_pickle_unpickleable_object(mj.MjvCamera)
# ============
[docs]
def agent_selected(agent: "MjCambrianAgent", agents: Optional[List[str]]):
    """Check whether `agent` is selected by a list of name patterns.

    Args:
        agent (MjCambrianAgent): The agent whose `name` is tested.
        agents (Optional[List[str]]): `fnmatch`-style patterns. None selects
            every agent.

    Returns:
        bool: True if `agents` is None or any pattern matches the agent's name.
    """
    if agents is None:
        return True
    for pattern in agents:
        if fnmatch(agent.name, pattern):
            return True
    return False