# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy
from typing import Callable, Dict, List, Optional
import torch
from tensordict import TensorDictBase
from torchrl.data import Composite
from torchrl.envs import (
DoubleToFloat,
DTypeCastTransform,
EnvBase,
FlattenObservation,
Transform,
)
from benchmarl.environments.common import Task, TaskClass
from benchmarl.utils import DEVICE_TYPING
class MeltingPotClass(TaskClass):
def get_env_fun(
self,
num_envs: int,
continuous_actions: bool,
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
from torchrl.envs.libs.meltingpot import MeltingpotEnv
config = copy.deepcopy(self.config)
return lambda: MeltingpotEnv(
substrate=self.name.lower(),
categorical_actions=True,
device=device,
**config,
)
def supports_continuous_actions(self) -> bool:
return False
def supports_discrete_actions(self) -> bool:
return True
def has_render(self, env: EnvBase) -> bool:
return True
def max_steps(self, env: EnvBase) -> int:
return self.config.get("max_steps", 100)
def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
return env.group_map
def get_env_transforms(self, env: EnvBase) -> List[Transform]:
interaction_inventories_keys = [
(group, "observation", "INTERACTION_INVENTORIES")
for group in self.group_map(env).keys()
if (group, "observation", "INTERACTION_INVENTORIES")
in env.observation_spec.keys(True, True)
]
return [DoubleToFloat()] + (
[
FlattenObservation(
in_keys=interaction_inventories_keys,
first_dim=-2,
last_dim=-1,
)
]
if len(interaction_inventories_keys)
else []
)
def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
return [
DTypeCastTransform(
dtype_in=torch.uint8,
dtype_out=torch.float,
in_keys=[
"RGB",
(group, "observation", "RGB"),
("next", "RGB"),
("next", group, "observation", "RGB"),
],
in_keys_inv=[],
)
]
def state_spec(self, env: EnvBase) -> Optional[Composite]:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
del observation_spec[group]
if list(observation_spec.keys()) != ["RGB"]:
raise ValueError(
f"More than one global state key found in observation spec {observation_spec}."
)
return observation_spec
def action_mask_spec(self, env: EnvBase) -> Optional[Composite]:
return None
def observation_spec(self, env: EnvBase) -> Composite:
observation_spec = env.observation_spec.clone()
for group_key in list(observation_spec.keys()):
if group_key not in self.group_map(env).keys():
del observation_spec[group_key]
return observation_spec
def info_spec(self, env: EnvBase) -> Optional[Composite]:
observation_spec = env.observation_spec.clone()
for group_key in list(observation_spec.keys()):
if group_key not in self.group_map(env).keys():
del observation_spec[group_key]
else:
group_obs_spec = observation_spec[group_key]["observation"]
del group_obs_spec["RGB"]
return observation_spec
def action_spec(self, env: EnvBase) -> Composite:
return env.full_action_spec
@staticmethod
def env_name() -> str:
return "meltingpot"
@staticmethod
def render_callback(experiment, env: EnvBase, data: TensorDictBase):
return data.get("RGB")
[docs]
class MeltingPotTask(Task):
"""Enum for meltingpot tasks."""
PREDATOR_PREY__ALLEY_HUNT = None
CLEAN_UP = None
COLLABORATIVE_COOKING__CIRCUIT = None
FRUIT_MARKET__CONCENTRIC_RIVERS = None
COLLABORATIVE_COOKING__FIGURE_EIGHT = None
PAINTBALL__KING_OF_THE_HILL = None
FACTORY_COMMONS__EITHER_OR = None
PURE_COORDINATION_IN_THE_MATRIX__ARENA = None
RUNNING_WITH_SCISSORS_IN_THE_MATRIX__REPEATED = None
COLLABORATIVE_COOKING__CRAMPED = None
RUNNING_WITH_SCISSORS_IN_THE_MATRIX__ARENA = None
PRISONERS_DILEMMA_IN_THE_MATRIX__REPEATED = None
TERRITORY__OPEN = None
STAG_HUNT_IN_THE_MATRIX__REPEATED = None
CHICKEN_IN_THE_MATRIX__REPEATED = None
GIFT_REFINEMENTS = None
PURE_COORDINATION_IN_THE_MATRIX__REPEATED = None
COLLABORATIVE_COOKING__FORCED = None
RATIONALIZABLE_COORDINATION_IN_THE_MATRIX__ARENA = None
BACH_OR_STRAVINSKY_IN_THE_MATRIX__ARENA = None
CHEMISTRY__TWO_METABOLIC_CYCLES_WITH_DISTRACTORS = None
COMMONS_HARVEST__PARTNERSHIP = None
PREDATOR_PREY__OPEN = None
TERRITORY__ROOMS = None
HIDDEN_AGENDA = None
COOP_MINING = None
DAYCARE = None
PRISONERS_DILEMMA_IN_THE_MATRIX__ARENA = None
TERRITORY__INSIDE_OUT = None
BACH_OR_STRAVINSKY_IN_THE_MATRIX__REPEATED = None
COMMONS_HARVEST__CLOSED = None
CHEMISTRY__THREE_METABOLIC_CYCLES_WITH_PLENTIFUL_DISTRACTORS = None
STAG_HUNT_IN_THE_MATRIX__ARENA = None
PAINTBALL__CAPTURE_THE_FLAG = None
COLLABORATIVE_COOKING__CROWDED = None
ALLELOPATHIC_HARVEST__OPEN = None
COLLABORATIVE_COOKING__RING = None
COMMONS_HARVEST__OPEN = None
COINS = None
PREDATOR_PREY__ORCHARD = None
PREDATOR_PREY__RANDOM_FOREST = None
COLLABORATIVE_COOKING__ASYMMETRIC = None
RATIONALIZABLE_COORDINATION_IN_THE_MATRIX__REPEATED = None
CHEMISTRY__THREE_METABOLIC_CYCLES = None
RUNNING_WITH_SCISSORS_IN_THE_MATRIX__ONE_SHOT = None
CHEMISTRY__TWO_METABOLIC_CYCLES = None
CHICKEN_IN_THE_MATRIX__ARENA = None
BOAT_RACE__EIGHT_RACES = None
EXTERNALITY_MUSHROOMS__DENSE = None
[docs]
@staticmethod
def associated_class():
return MeltingPotClass