# Source code for benchmarl.environments.pettingzoo.common

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from typing import Callable, Dict, List, Optional

from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase, PettingZooEnv

from benchmarl.environments.common import Task

from benchmarl.utils import DEVICE_TYPING


class PettingZooTask(Task):
    """Enum for PettingZoo tasks."""

    # NOTE(review): every member is declared with a ``None`` value here, while
    # the methods below read ``self.config``; presumably the base ``Task``
    # attaches the configuration elsewhere -- confirm against ``Task``.
    MULTIWALKER = None
    WATERWORLD = None
    SIMPLE_ADVERSARY = None
    SIMPLE_CRYPTO = None
    SIMPLE_PUSH = None
    SIMPLE_REFERENCE = None
    SIMPLE_SPEAKER_LISTENER = None
    SIMPLE_SPREAD = None
    SIMPLE_TAG = None
    SIMPLE_WORLD_COMM = None

    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
        device: DEVICE_TYPING,
    ) -> Callable[[], EnvBase]:
        """Return a zero-argument factory that builds the ``PettingZooEnv``.

        Args:
            num_envs: Not used by this implementation.
            continuous_actions: Requested action type. It is only written to
                the task config when the task supports both continuous and
                discrete actions.
            seed: Seed forwarded to the environment constructor.
            device: Device forwarded to the environment constructor.

        Returns:
            A callable that constructs a new ``PettingZooEnv`` each time it is
            invoked, using the task config as extra keyword arguments.
        """
        if self.supports_continuous_actions() and self.supports_discrete_actions():
            # The config dict is updated in place, so the choice is forwarded
            # to the constructor through ``**self.config`` below.
            self.config.update({"continuous_actions": continuous_actions})

        return lambda: PettingZooEnv(
            categorical_actions=True,
            device=device,
            seed=seed,
            parallel=True,
            return_state=self.has_state(),
            render_mode="rgb_array",
            **self.config
        )

    def supports_continuous_actions(self) -> bool:
        """Return ``True`` if this task can be run with continuous actions."""
        return self in {
            PettingZooTask.MULTIWALKER,
            PettingZooTask.WATERWORLD,
            PettingZooTask.SIMPLE_ADVERSARY,
            PettingZooTask.SIMPLE_CRYPTO,
            PettingZooTask.SIMPLE_PUSH,
            PettingZooTask.SIMPLE_REFERENCE,
            PettingZooTask.SIMPLE_SPEAKER_LISTENER,
            PettingZooTask.SIMPLE_SPREAD,
            PettingZooTask.SIMPLE_TAG,
            PettingZooTask.SIMPLE_WORLD_COMM,
        }

    def supports_discrete_actions(self) -> bool:
        """Return ``True`` if this task can be run with discrete actions."""
        return self in {
            PettingZooTask.SIMPLE_ADVERSARY,
            PettingZooTask.SIMPLE_CRYPTO,
            PettingZooTask.SIMPLE_PUSH,
            PettingZooTask.SIMPLE_REFERENCE,
            PettingZooTask.SIMPLE_SPEAKER_LISTENER,
            PettingZooTask.SIMPLE_SPREAD,
            PettingZooTask.SIMPLE_TAG,
            PettingZooTask.SIMPLE_WORLD_COMM,
        }

    def has_state(self) -> bool:
        """Return ``True`` if the environment is asked to return a state.

        The result is passed as ``return_state`` when the environment is
        constructed in :meth:`get_env_fun`.
        """
        return self in {
            PettingZooTask.SIMPLE_ADVERSARY,
            PettingZooTask.SIMPLE_CRYPTO,
            PettingZooTask.SIMPLE_PUSH,
            PettingZooTask.SIMPLE_REFERENCE,
            PettingZooTask.SIMPLE_SPEAKER_LISTENER,
            PettingZooTask.SIMPLE_SPREAD,
            PettingZooTask.SIMPLE_TAG,
            PettingZooTask.SIMPLE_WORLD_COMM,
        }

    def has_render(self, env: EnvBase) -> bool:
        """Return ``True``: environments are built with ``render_mode="rgb_array"``."""
        return True

    def max_steps(self, env: EnvBase) -> int:
        """Return the ``"max_cycles"`` entry of the task config."""
        return self.config["max_cycles"]

    def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        """Return the group map exposed by the environment."""
        return env.group_map

    def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return a spec holding only the ``"state"`` entry, or ``None``.

        ``None`` is returned when the environment's observation spec has no
        ``"state"`` entry.
        """
        if "state" in env.observation_spec:
            return CompositeSpec({"state": env.observation_spec["state"].clone()})
        return None

    def _per_group_spec(self, env: EnvBase, kept_key: str) -> CompositeSpec:
        """Clone the observation spec keeping only ``kept_key`` in each group.

        The top-level ``"state"`` entry, when present, is removed as well: it
        is exposed separately through :meth:`state_spec`.
        """
        spec = env.observation_spec.clone()
        for group in self.group_map(env):
            group_spec = spec[group]
            # Materialise the keys first: entries are deleted while iterating.
            for key in list(group_spec.keys()):
                if key != kept_key:
                    del group_spec[key]
        if "state" in spec.keys():
            del spec["state"]
        return spec

    def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return a spec with only the per-group ``"action_mask"`` entries.

        Groups without an action mask are dropped. ``None`` is returned when
        no group has one.
        """
        spec = self._per_group_spec(env, "action_mask")
        for group in self.group_map(env):
            if spec[group].is_empty():
                del spec[group]
        if spec.is_empty():
            return None
        return spec

    def observation_spec(self, env: EnvBase) -> CompositeSpec:
        """Return a spec with only the per-group ``"observation"`` entries.

        The global ``"state"`` entry is not included, consistently with
        :meth:`action_mask_spec`.
        """
        return self._per_group_spec(env, "observation")

    def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return a spec with only the per-group ``"info"`` entries.

        The global ``"state"`` entry is not included, consistently with
        :meth:`action_mask_spec`. Groups without an ``"info"`` entry are kept
        as empty specs.
        """
        return self._per_group_spec(env, "info")

    def action_spec(self, env: EnvBase) -> CompositeSpec:
        """Return the ``"full_action_spec"`` entry of the environment input spec."""
        return env.input_spec["full_action_spec"]

    @staticmethod
    def env_name() -> str:
        """Return the name identifying this environment family."""
        return "pettingzoo"