Source code for benchmarl.environments.vmas.common

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from typing import Callable, Dict, List, Optional

from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase
from torchrl.envs.libs.vmas import VmasEnv

from benchmarl.environments.common import Task
from benchmarl.utils import DEVICE_TYPING


class VmasTask(Task):
    """Enum for VMAS tasks.

    The lower-cased member name is used as the VMAS scenario name when the
    environment factory returned by ``get_env_fun`` builds a ``VmasEnv``.
    """

    # NOTE: in an Enum body every class-level assignment becomes a member, so
    # only scenario entries belong here; any helper must be a method.
    BALANCE = None
    SAMPLING = None
    NAVIGATION = None
    TRANSPORT = None
    REVERSE_TRANSPORT = None
    WHEEL = None
    DISPERSION = None
    DROPOUT = None
    GIVE_WAY = None
    WIND_FLOCKING = None
    SIMPLE_ADVERSARY = None
    SIMPLE_CRYPTO = None
    SIMPLE_PUSH = None
    SIMPLE_REFERENCE = None
    SIMPLE_SPEAKER_LISTENER = None
    SIMPLE_SPREAD = None
    SIMPLE_TAG = None
    SIMPLE_WORLD_COMM = None

    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
        device: DEVICE_TYPING,
    ) -> Callable[[], EnvBase]:
        """Build a zero-argument factory that creates a ``VmasEnv`` for this task.

        Args:
            num_envs: forwarded to ``VmasEnv`` as ``num_envs``.
            continuous_actions: forwarded to ``VmasEnv`` as ``continuous_actions``.
            seed: forwarded to ``VmasEnv`` as ``seed`` (may be ``None``).
            device: forwarded to ``VmasEnv`` as ``device``.

        Returns:
            A callable that, each time it is invoked, constructs a new
            ``VmasEnv`` with categorical and clamped actions plus every entry
            of the task config as extra keyword arguments.
        """

        def _make_env() -> EnvBase:
            # ``self.config`` is looked up here, i.e. when the factory runs,
            # not when ``get_env_fun`` is called.
            return VmasEnv(
                scenario=self.name.lower(),
                num_envs=num_envs,
                continuous_actions=continuous_actions,
                seed=seed,
                device=device,
                categorical_actions=True,
                clamp_actions=True,
                **self.config,
            )

        return _make_env

    def supports_continuous_actions(self) -> bool:
        """VMAS tasks can be run with continuous actions."""
        return True

    def supports_discrete_actions(self) -> bool:
        """VMAS tasks can be run with discrete actions."""
        return True

    def has_render(self, env: EnvBase) -> bool:
        """Rendering is reported as available for every VMAS task."""
        return True

    def max_steps(self, env: EnvBase) -> int:
        """Return the ``"max_steps"`` entry of the task config (KeyError if absent)."""
        return self.config["max_steps"]

    def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        """Map group names to agent names.

        Uses ``env.group_map`` when the environment defines it; otherwise all
        agents are placed in a single ``"agents"`` group.
        """
        return (
            env.group_map
            if hasattr(env, "group_map")
            else {"agents": [agent.name for agent in env.agents]}
        )

    def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """VMAS tasks expose no global state spec."""
        return None

    def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """VMAS tasks expose no action mask spec."""
        return None

    def observation_spec(self, env: EnvBase) -> CompositeSpec:
        """Return a copy of the unbatched observation spec without ``"info"`` entries.

        The env's spec is cloned first, so the deletions below do not touch it.
        """
        spec = env.unbatched_observation_spec.clone()
        groups_with_info = [
            group for group in self.group_map(env) if "info" in spec[group]
        ]
        for group in groups_with_info:
            del spec[(group, "info")]
        return spec

    def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        """Return the unbatched observation spec minus each group's ``"observation"``.

        Returns ``None`` when no group carries an ``"info"`` entry.
        """
        spec = env.unbatched_observation_spec.clone()
        groups = self.group_map(env)
        for group in groups:
            del spec[(group, "observation")]
        has_info = any("info" in spec[group] for group in groups)
        return spec if has_info else None

    def action_spec(self, env: EnvBase) -> CompositeSpec:
        """Return the env's unbatched action spec (the env's own object, not a clone)."""
        return env.unbatched_action_spec

    @staticmethod
    def env_name() -> str:
        """Identifier of this environment family."""
        return "vmas"