# Source code for benchmarl.environments.vmas.common

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#
import copy
from typing import Callable, Dict, List, Optional

from torchrl.data import Composite
from torchrl.envs import EnvBase
from torchrl.envs.libs.vmas import VmasEnv

from benchmarl.environments.common import Task, TaskClass
from benchmarl.utils import DEVICE_TYPING


class VmasClass(TaskClass):
    """Task implementation whose environments are built with TorchRL's ``VmasEnv``."""

    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
        device: DEVICE_TYPING,
    ) -> Callable[[], EnvBase]:
        """Build a zero-argument factory that instantiates the ``VmasEnv``.

        Args:
            num_envs: forwarded to ``VmasEnv`` as ``num_envs``.
            continuous_actions: forwarded to ``VmasEnv`` as ``continuous_actions``.
            seed: forwarded to ``VmasEnv`` as ``seed``; may be ``None``.
            device: forwarded to ``VmasEnv`` as ``device``.

        Returns:
            A callable that creates a new ``VmasEnv`` each time it is invoked.
        """
        # Deep-copy up front: the factory keeps using this snapshot even if
        # ``self.config`` is modified after this method returns.
        scenario_kwargs = copy.deepcopy(self.config)

        def make_env() -> EnvBase:
            # The scenario name is the lower-cased task name, resolved at call time.
            return VmasEnv(
                scenario=self.name.lower(),
                num_envs=num_envs,
                continuous_actions=continuous_actions,
                seed=seed,
                device=device,
                categorical_actions=True,
                clamp_actions=True,
                **scenario_kwargs,
            )

        return make_env

    def supports_continuous_actions(self) -> bool:
        """Always ``True`` for this task class."""
        return True

    def supports_discrete_actions(self) -> bool:
        """Always ``True`` for this task class."""
        return True

    def has_render(self, env: EnvBase) -> bool:
        """Always ``True``; ``env`` is not inspected."""
        return True

    def max_steps(self, env: EnvBase) -> int:
        """Return the ``"max_steps"`` entry of the task config; ``env`` is not inspected."""
        return self.config["max_steps"]

    def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        """Return the mapping from group name to the agent names in that group.

        The environment's own ``group_map`` attribute is used when it exists;
        otherwise all agents are gathered under a single ``"agents"`` group.
        """
        if not hasattr(env, "group_map"):
            agent_names = [agent.name for agent in env.agents]
            return {"agents": agent_names}
        return env.group_map

    def state_spec(self, env: EnvBase) -> Optional[Composite]:
        """Always ``None``: no state spec is exposed."""
        return None

    def action_mask_spec(self, env: EnvBase) -> Optional[Composite]:
        """Always ``None``: no action-mask spec is exposed."""
        return None

    def observation_spec(self, env: EnvBase) -> Composite:
        """Return a clone of the unbatched observation spec without ``"info"`` entries.

        For every group in :meth:`group_map`, the ``(group, "info")`` entry is
        removed from the clone when that group has one.
        """
        spec = env.full_observation_spec_unbatched.clone()
        for group_name in self.group_map(env):
            group_spec = spec[group_name]
            if "info" in group_spec:
                del spec[(group_name, "info")]
        return spec

    def info_spec(self, env: EnvBase) -> Optional[Composite]:
        """Return a clone of the unbatched observation spec without ``"observation"`` entries.

        The ``(group, "observation")`` entry is removed for every group first.
        The stripped clone is returned if at least one group has an ``"info"``
        entry; otherwise ``None`` is returned.
        """
        spec = env.full_observation_spec_unbatched.clone()
        # Strip the observations from all groups before looking for info entries.
        for group_name in self.group_map(env):
            del spec[(group_name, "observation")]

        has_info = any("info" in spec[group_name] for group_name in self.group_map(env))
        return spec if has_info else None

    def action_spec(self, env: EnvBase) -> Composite:
        """Return the environment's unbatched full action spec (not cloned)."""
        return env.full_action_spec_unbatched

    @staticmethod
    def env_name() -> str:
        """Return the environment name string, ``"vmas"``."""
        return "vmas"


class VmasTask(Task):
    """Enum for VMAS tasks."""

    # NOTE(review): every member is assigned ``None`` here; how the ``Task``
    # base class interprets these values is not visible in this file.
    BALANCE = None
    SAMPLING = None
    NAVIGATION = None
    TRANSPORT = None
    REVERSE_TRANSPORT = None
    WHEEL = None
    DISPERSION = None
    MULTI_GIVE_WAY = None
    DROPOUT = None
    GIVE_WAY = None
    WIND_FLOCKING = None
    PASSAGE = None
    JOINT_PASSAGE = None
    JOINT_PASSAGE_SIZE = None
    BALL_PASSAGE = None
    BALL_TRAJECTORY = None
    BUZZ_WIRE = None
    FLOCKING = None
    DISCOVERY = None
    FOOTBALL = None
    SIMPLE_ADVERSARY = None
    SIMPLE_CRYPTO = None
    SIMPLE_PUSH = None
    SIMPLE_REFERENCE = None
    SIMPLE_SPEAKER_LISTENER = None
    SIMPLE_SPREAD = None
    SIMPLE_TAG = None
    SIMPLE_WORLD_COMM = None

    @staticmethod
    def associated_class():
        """Return the task class associated with these tasks, ``VmasClass``."""
        return VmasClass