Source code for cambrian.renderer.overlays

"""Defines utilities for overlays in the Mujoco viewer."""

from dataclasses import dataclass, replace
from enum import Enum
from typing import Any, Optional, Tuple

import mujoco as mj
import numpy as np
import torch

from cambrian.renderer.render_utils import resize_with_aspect_fill
from cambrian.utils.logger import get_logger


@dataclass
class MjCambrianCursor:
    """Pixel position at which an overlay is drawn inside a container.

    A cursor is created in one of two ways:

    * with explicit ``x`` and ``y`` coordinates, or
    * with ``container_width`` and ``container_height``, in which case any
      coordinate that was not given is derived from ``position`` and
      ``margin`` (see :meth:`_set_position`).

    Attributes are documented inline below.
    """

    class Position(Enum):
        """Corner of the container the cursor is anchored to."""

        TOP_LEFT = 1
        TOP_RIGHT = 2
        BOTTOM_LEFT = 3
        BOTTOM_RIGHT = 4

    class Layer(Enum):
        """Layer hint exposed to callers through ``MjCambrianViewerOverlay.layer``."""

        BACK = 1
        AUTO = 2
        FRONT = 3

    # Size of the area the cursor lives in; both are needed to derive x/y.
    container_width: int | None = None
    container_height: int | None = None
    # Corner used as the starting point when x/y are derived.
    position: Position = Position.TOP_LEFT
    layer: Layer = Layer.AUTO
    # Distance kept from the container edge when x/y are derived.
    margin: int = TEXT_MARGIN
    # Explicit coordinates; they take precedence over derived ones.
    x: int | None = None
    y: int | None = None

    def __post_init__(self):
        """Derive missing coordinates and validate the cursor.

        Raises:
            AssertionError: If neither (``container_width`` and
                ``container_height``) nor (``x`` and ``y``) were provided.
        """
        # When the container size is known, fill in whichever of x/y is missing.
        # Explicitly passed coordinates are left untouched by _set_position().
        if self.container_width is not None and self.container_height is not None:
            self._set_position()

        # Kept as an assert so the exception type seen by callers is unchanged.
        assert (
            self.x is not None and self.y is not None
        ), "Either container width and height or x and y must be provided."

    def _set_position(self):
        """Fill unset ``x``/``y`` from ``position``, ``margin`` and the container.

        Requires ``container_width`` and ``container_height`` to be set. TOP_*
        positions start near ``y == container_height`` and BOTTOM_* positions
        near ``y == margin``.
        """
        # NOTE(review): the *_LEFT and *_RIGHT offsets are not symmetric
        # (TOP_LEFT ignores margin/TEXT_HEIGHT, BOTTOM_LEFT subtracts
        # TEXT_HEIGHT). They are preserved exactly as found -- confirm intent.
        if self.position == MjCambrianCursor.Position.TOP_LEFT:
            x = self.margin
            y = self.container_height
        elif self.position == MjCambrianCursor.Position.TOP_RIGHT:
            x = self.container_width - self.margin
            y = self.container_height - self.margin - TEXT_HEIGHT
        elif self.position == MjCambrianCursor.Position.BOTTOM_LEFT:
            x = self.margin
            y = self.margin - TEXT_HEIGHT
        elif self.position == MjCambrianCursor.Position.BOTTOM_RIGHT:
            x = self.container_width - self.margin
            y = self.margin

        if self.x is None:
            self.x = x
        if self.y is None:
            self.y = y

    def move(self, dx: int, dy: int) -> "MjCambrianCursor":
        """Shift the cursor in place by ``(dx, dy)`` and return ``self``."""
        self.x += dx
        self.y += dy
        return self

    def copy(self) -> "MjCambrianCursor":
        """Return an independent cursor with the same field values."""
        return replace(self)

class MjCambrianViewerOverlay:
    """This class is used to add an overlay to the viewer.

    Note:
        This is applied only to the passed scene, so other scenes (i.e. ones for
        the eyes) will not be affected.

    Args:
        obj: The payload to draw; its meaning is defined by the subclass.
        cursor: Optional location of the overlay. A copy is stored, so later
            changes to the passed cursor do not move this overlay.
    """

    def __init__(self, obj: Any, cursor: Optional[MjCambrianCursor] = None):
        self._obj = obj
        self._cursor = cursor.copy() if cursor is not None else None

    def draw_before_render(self, scene: mj.MjvScene):
        """Called before rendering the scene. No-op in the base class."""
        pass

    def draw_after_render(self, mjr_context: mj.MjrContext, viewport: mj.MjrRect):
        """Called after rendering the scene. No-op in the base class."""
        pass

    def place(self, cursor: MjCambrianCursor) -> MjCambrianCursor:
        """Places the cursor at the given location.

        Won't overwrite the current cursor, if it exists.

        Args:
            cursor: The cursor to place. A copy of it is stored.

        Returns:
            The cursor that was passed in (not the stored copy).
        """
        if self._cursor is None:
            self._cursor = cursor.copy()
        return cursor

    @property
    def layer(self) -> MjCambrianCursor.Layer:
        """Layer of this overlay's cursor, or ``Layer.AUTO`` without a cursor."""
        return (
            self._cursor.layer
            if self._cursor is not None
            else MjCambrianCursor.Layer.AUTO
        )

    # =============

    @staticmethod
    def create_text_overlay(
        text: str,
        *,
        cursor: MjCambrianCursor | None = None,
    ) -> "MjCambrianTextViewerOverlay":
        """Creates a text overlay."""
        return MjCambrianTextViewerOverlay(text, cursor)

    @staticmethod
    def create_image_overlay(
        obj: torch.Tensor, *, cursor: MjCambrianCursor | None = None
    ) -> "MjCambrianImageViewerOverlay":
        """Creates an image overlay."""
        return MjCambrianImageViewerOverlay(obj, cursor)

    @staticmethod
    def create_site_overlay(
        pos: np.ndarray,
        rgba: Tuple[float, float, float, float],
        size: float,
        geom_kwargs: Optional[dict] = None,
    ) -> "MjCambrianSiteViewerOverlay":
        """Creates a site overlay.

        Args:
            pos: Forwarded to ``MjCambrianSiteViewerOverlay`` as its position.
            rgba: Forwarded as the site color.
            size: Forwarded as the site size.
            geom_kwargs: Extra attributes to set on the geom. Defaults to
                ``dict(emission=0.25)``.
        """
        # Build the default per call so overlays never share one dict.
        if geom_kwargs is None:
            geom_kwargs = dict(emission=0.25)
        return MjCambrianSiteViewerOverlay(pos, rgba, size, geom_kwargs)
class MjCambrianTextViewerOverlay(MjCambrianViewerOverlay):
    """Overlay that draws its payload as text on the rendered frame."""

    def draw_after_render(self, mjr_context: mj.MjrContext, viewport: mj.MjrRect):
        """Draw the text with ``mj.mjr_overlay``.

        Args:
            mjr_context: Render context handed to MuJoCo.
            viewport: Used as-is when this overlay has no cursor; otherwise
                replaced by a 1x1 rectangle at the cursor position.
        """
        anchor = self._cursor
        if anchor is not None:
            viewport = mj.MjrRect(anchor.x, anchor.y, 1, 1)

        mj.mjr_overlay(
            mj.mjtFont.mjFONT_NORMAL,
            mj.mjtGridPos.mjGRID_BOTTOMLEFT,
            viewport,
            self._obj,
            "",
            mjr_context,
        )

    def place(self, cursor: MjCambrianCursor) -> MjCambrianCursor:
        """Advance ``cursor`` by one text line, then place this overlay there.

        The passed cursor is moved in place: up for BOTTOM_* positions, down
        for every other position.

        Args:
            cursor: The cursor to advance and place.

        Returns:
            The (moved) cursor that was passed in.
        """
        line_step = TEXT_HEIGHT + TEXT_MARGIN
        anchored_to_bottom = cursor.position in (
            MjCambrianCursor.Position.BOTTOM_LEFT,
            MjCambrianCursor.Position.BOTTOM_RIGHT,
        )
        cursor.move(0, line_step if anchored_to_bottom else -line_step)
        return super().place(cursor)
class MjCambrianImageViewerOverlay(MjCambrianViewerOverlay):
    """Overlay that draws a tensor as an image on the rendered frame.

    Args:
        obj: The image tensor. A CPU counterpart is kept in ``_obj_cpu`` and
            refreshed from ``obj`` on every draw.
        cursor: Optional location of the image.
    """

    def __init__(self, obj: torch.Tensor, cursor: Optional[MjCambrianCursor] = None):
        super().__init__(obj, cursor)
        self._obj_cpu = obj.cpu()

    def draw_after_render(self, mjr_context: mj.MjrContext, viewport: mj.MjrRect):
        """Copy the current image to the CPU buffer and draw it at the cursor.

        Args:
            mjr_context: Render context handed to MuJoCo.
            viewport: Ignored; the rectangle is built from the cursor and the
                image shape.

        Raises:
            AssertionError: If the overlay has not been given a cursor.
        """
        assert self._cursor is not None

        # Refresh the CPU-side buffer from the source tensor.
        self._obj_cpu.copy_(self._obj, non_blocking=True)

        # Dimension 0 is used as the rectangle height and dimension 1 as its width.
        height, width = self._obj.shape[0], self._obj.shape[1]
        image_rect = mj.MjrRect(self._cursor.x, self._cursor.y, width, height)

        pixels = self._obj_cpu.numpy().ravel()
        mj.mjr_drawPixels(pixels, None, image_rect, mjr_context)

    def place(self, cursor: MjCambrianCursor) -> MjCambrianCursor:
        """Resize the image to the cursor's container, then place the overlay.

        The cursor itself is not moved.

        Args:
            cursor: The cursor to place; its container size drives the resize.

        Returns:
            The cursor that was passed in.
        """
        resized = resize_with_aspect_fill(
            self._obj, cursor.container_height, cursor.container_width
        )
        self._obj = resized
        self._obj_cpu = resized.cpu()
        return super().place(cursor)
class MjCambrianSiteViewerOverlay(MjCambrianViewerOverlay):
    """This class is used to add a site to the viewer.

    The site is drawn as a sphere geom appended to the scene.

    Todo:
        Make this an image overlay where the pos is converted to pixel
        coordinates.

    Args:
        pos: Position passed to ``mj.mjv_initGeom``.
        rgba: Color passed to ``mj.mjv_initGeom``.
        size: Used for all three size components of the sphere geom.
        geom_kwargs: Attribute name/value pairs set on the geom after it is
            initialised. Defaults to ``dict(emission=0.25)``.
    """

    def __init__(
        self,
        pos: np.ndarray,
        rgba: Tuple[float, float, float, float],
        size: float,
        geom_kwargs: Optional[dict] = None,
    ):
        super().__init__(pos)
        self._rgba = rgba
        self._size = size
        # Create the default per instance; a dict in the signature would be
        # shared by every overlay built without an explicit argument.
        self._geom_kwargs = dict(emission=0.25) if geom_kwargs is None else geom_kwargs

    def draw_before_render(self, scene: mj.MjvScene):
        """Append this site to ``scene`` as a sphere geom.

        Logs a warning and leaves the scene untouched when it is already full.

        Args:
            scene: The scene to add the geom to.
        """
        if scene.ngeom >= scene.maxgeom:
            get_logger().warning(
                f"Max geom reached ({scene.maxgeom}). Cannot add more sites."
            )
            return

        # Claim the next free slot, then initialise it.
        scene.ngeom += 1
        geom = scene.geoms[scene.ngeom - 1]
        mj.mjv_initGeom(
            geom,
            mj.mjtGeom.mjGEOM_SPHERE,
            [self._size] * 3,
            self._obj,
            np.eye(3).flatten(),
            self._rgba,
        )

        # Apply the extra attributes on top of the initialised geom.
        for key, value in self._geom_kwargs.items():
            setattr(geom, key, value)