Source code for benchmarl.models.gnn

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from __future__ import annotations

import importlib
import inspect
import warnings
from dataclasses import dataclass, MISSING
from math import prod
from typing import List, Optional, Type

import torch
from tensordict import TensorDictBase
from tensordict.utils import _unravel_key_to_tuple, NestedKey
from torch import nn, Tensor

from benchmarl.models.common import Model, ModelConfig

_has_torch_geometric = importlib.util.find_spec("torch_geometric") is not None
if _has_torch_geometric:
    import torch_geometric
    from torch_geometric.transforms import BaseTransform

    class _RelVel(BaseTransform):
        """Transform that reads graph.vel and writes node1.vel - node2.vel in the edge attributes"""

        def __init__(self):
            pass

        def forward(self, data):
            (row, col), vel, pseudo = data.edge_index, data.vel, data.edge_attr

            cart = vel[row] - vel[col]
            cart = cart.view(-1, 1) if cart.dim() == 1 else cart

            if pseudo is not None:
                pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
                data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
            else:
                data.edge_attr = cart
            return data


TOPOLOGY_TYPES = {"full", "empty", "from_pos"}


[docs] class Gnn(Model): """A GNN model. GNN models can be used as "decentralized" actors or critics. Args: topology (str): Topology of the graph adjacency matrix. Options: "full", "empty", "from_pos". "from_pos" builds the topology dynamically based on ``position_key`` and ``edge_radius``. self_loops (str): Whether the resulting adjacency matrix will have self loops. gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env (in the `observation_spec`) representing the agent position. To do this, your environment needs to have dictionary observations and one of the keys needs to be `position_key`. This key will be processed as a node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features. In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a one-dimensional distance for all neighbours in the graph. If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence. pos_features (int, optional): Needed when position_key is specified. It has to match to the last element of the shape the tensor under position_key. exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided, wether to use it just to compute edge features or also include it in node features. velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env (in the `observation_spec`) representing the agent position. To do this, your environment needs to have dictionary observations and one of the keys needs to be `velocity_key`. This key will be processed as a node feature, and it will be used to construct edge features. In particular, it will be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours in the graph. If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence. vel_features (int, optional): Needed when velocity_key is specified. It has to match to the last element of the shape the tensor under velocity_key. edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph. Agents within this radius distance will be neighnours. Examples: .. code-block:: python import torch_geometric from torch import nn from benchmarl.algorithms import IppoConfig from benchmarl.environments import VmasTask from benchmarl.experiment import Experiment, ExperimentConfig from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig experiment = Experiment( algorithm_config=IppoConfig.get_from_yaml(), model_config=GnnConfig( topology="full", self_loops=False, gnn_class=torch_geometric.nn.conv.GATv2Conv, gnn_kwargs={}, ), critic_model_config=SequenceModelConfig( model_configs=[ MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear), GnnConfig( topology="full", self_loops=False, gnn_class=torch_geometric.nn.conv.GraphConv, ), MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear), ], intermediate_sizes=[5,3], ), seed=0, config=ExperimentConfig.get_from_yaml(), task=VmasTask.NAVIGATION.get_from_yaml(), ) experiment.run() """ def __init__( self, topology: str, self_loops: bool, gnn_class: Type[torch_geometric.nn.MessagePassing], gnn_kwargs: Optional[dict], position_key: Optional[str], exclude_pos_from_node_features: Optional[bool], velocity_key: Optional[str], edge_radius: Optional[float], pos_features: Optional[int], vel_features: Optional[int], **kwargs, ): self.topology = topology self.self_loops = self_loops self.position_key = position_key self.velocity_key = velocity_key self.exclude_pos_from_node_features = exclude_pos_from_node_features self.edge_radius = edge_radius self.pos_features = pos_features self.vel_features = vel_features super().__init__(**kwargs) if self.pos_features > 0: self.pos_features += 1 # We will add also 1-dimensional distance self.edge_features = self.pos_features + self.vel_features self.input_features = sum( [ spec.shape[-1] for key, spec in self.input_spec.items(True, True) if _unravel_key_to_tuple(key)[-1] not in (position_key, velocity_key) ] ) # Input keys if self.position_key is not None and not self.exclude_pos_from_node_features: self.input_features += self.pos_features - 1 if self.velocity_key is not None: self.input_features += self.vel_features self.output_features = self.output_leaf_spec.shape[-1] if gnn_kwargs is None: gnn_kwargs = {} gnn_kwargs.update( {"in_channels": self.input_features, "out_channels": self.output_features} ) self.gnn_supports_edge_attrs = ( "edge_dim" in inspect.getfullargspec(gnn_class).args ) if ( self.position_key is not None or self.velocity_key is not None ) and not self.gnn_supports_edge_attrs: warnings.warn( "Position key or velocity key provided but GNN class does not support edge attributes. " "These keys will not be used for computing edge features." ) if ( position_key is not None or velocity_key is not None ) and self.gnn_supports_edge_attrs: gnn_kwargs.update({"edge_dim": self.edge_features}) self.gnns = nn.ModuleList( [ gnn_class(**gnn_kwargs).to(self.device) for _ in range(self.n_agents if not self.share_params else 1) ] ) self.edge_index = _get_edge_index( topology=self.topology, self_loops=self.self_loops, device=self.device, n_agents=self.n_agents, ) self._full_position_key = None self._full_velocity_key = None def _perform_checks(self): super()._perform_checks() if self.topology not in TOPOLOGY_TYPES: raise ValueError( f"Got topology: {self.topology} but only available options are {TOPOLOGY_TYPES}" ) if self.topology == "from_pos" and self.position_key is None: raise ValueError("If topology is from_pos, position_key must be provided") if ( self.position_key is not None and self.exclude_pos_from_node_features is None ): raise ValueError( "exclude_pos_from_node_features needs to be specified when position_key is provided" ) if self.position_key is not None and self.pos_features <= 0: raise ValueError( f"Position key specified but pos_features is {self.pos_features}" ) elif self.position_key is None and self.pos_features > 0: raise ValueError( f"If no position_key is given, pos_features needs to be 0, got: {self.pos_features}" ) if self.velocity_key is not None and self.vel_features <= 0: raise ValueError( f"Velocity key specified but vel_features is {self.vel_features}" ) elif self.velocity_key is None and self.vel_features > 0: raise ValueError( f"If no velocity_key is given, vel_features needs to be 0, got: {self.vel_features}" ) if not self.input_has_agent_dim: raise ValueError( "The GNN module is not compatible with input that does not have the agent dimension," "such as the global state in centralised critics. Please choose another critic model" "if your algorithm has a centralized critic and the task has a global state." "If you are using the GNN in a centralized critic, it should be the first layer." ) input_shape = None for input_key, input_spec in self.input_spec.items(True, True): if len(input_spec.shape) == 2: if input_shape is None: input_shape = input_spec.shape[:-1] else: if input_spec.shape[:-1] != input_shape: raise ValueError( f"GNN inputs should all have the same shape up to the last dimension, got {self.input_spec}" ) else: raise ValueError( f"GNN input value {input_key} from {self.input_spec} has an invalid shape" ) if input_shape[-1] != self.n_agents: raise ValueError( f"The second to last input spec dimension should be the number of agents, got {self.input_spec}" ) if ( self.output_has_agent_dim and self.output_leaf_spec.shape[-2] != self.n_agents ): raise ValueError( "If the GNN output has the agent dimension," " the second to last spec dimension should be the number of agents" )
[docs] def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key input = [ tensordict.get(in_key) for in_key in self.in_keys if _unravel_key_to_tuple(in_key)[-1] not in (self.position_key, self.velocity_key) ] # Retrieve position if self.position_key is not None: if self._full_position_key is None: # Run once to find full key self._full_position_key = self._get_key_terminating_with( list(tensordict.keys(True, True)), self.position_key ) pos = tensordict.get(self._full_position_key) if pos.shape[-1] != self.pos_features - 1: raise ValueError( f"Position key in tensordict is {pos.shape[-1]}-dimensional, " f"while model was configured with pos_features={self.pos_features-1}" ) else: pos = tensordict.get(self._full_position_key) if not self.exclude_pos_from_node_features: input.append(pos) else: pos = None # Retrieve velocity if self.velocity_key is not None: if self._full_velocity_key is None: # Run once to find full key self._full_velocity_key = self._get_key_terminating_with( list(tensordict.keys(True, True)), self.velocity_key ) vel = tensordict.get(self._full_velocity_key) if vel.shape[-1] != self.vel_features: raise ValueError( f"Velocity key in tensordict is {vel.shape[-1]}-dimensional, " f"while model was configured with vel_features={self.vel_features}" ) else: vel = tensordict.get(self._full_velocity_key) input.append(vel) else: vel = None input = torch.cat(input, dim=-1) batch_size = input.shape[:-2] graph = _batch_from_dense_to_ptg( x=input, edge_index=self.edge_index, pos=pos, vel=vel, self_loops=self.self_loops, edge_radius=self.edge_radius, ) forward_gnn_params = { "x": graph.x, "edge_index": graph.edge_index, } if ( self.position_key is not None or self.velocity_key is not None ) and self.gnn_supports_edge_attrs: forward_gnn_params.update({"edge_attr": graph.edge_attr}) if not self.share_params: if not self.centralised: res = torch.stack( [ gnn(**forward_gnn_params).view( *batch_size, self.n_agents, self.output_features, )[..., i, :] for i, gnn in enumerate(self.gnns) ], dim=-2, ) else: res = torch.stack( [ gnn(**forward_gnn_params) .view( *batch_size, self.n_agents, self.output_features, ) .mean(dim=-2) # Mean pooling for i, gnn in enumerate(self.gnns) ], dim=-2, ) else: res = self.gnns[0](**forward_gnn_params).view( *batch_size, self.n_agents, self.output_features ) if self.centralised: res = res.mean(dim=-2) # Mean pooling tensordict.set(self.out_key, res) return tensordict
def _get_key_terminating_with(self, keys: List[NestedKey], key: str) -> NestedKey: for k in keys: k_tuple = _unravel_key_to_tuple(k) if ( k_tuple[-1] == key and self.agent_group in k_tuple and not "next" == k_tuple[0] ): return k raise KeyError( f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}. " f"If you are using the GNN in a `SequenceModel` and want to use this key, it needs to be the first model." )
def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str): if topology == "full": adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long) edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency) if not self_loops: edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) elif topology == "empty": if self_loops: edge_index = ( torch.arange(n_agents, device=device, dtype=torch.long) .unsqueeze(0) .repeat(2, 1) ) else: edge_index = torch.empty((2, 0), device=device, dtype=torch.long) elif topology == "from_pos": edge_index = None else: raise ValueError(f"Topology {topology} not supported") return edge_index def _batch_from_dense_to_ptg( x: Tensor, edge_index: Optional[Tensor], self_loops: bool, pos: Tensor = None, vel: Tensor = None, edge_radius: Optional[float] = None, ) -> torch_geometric.data.Batch: batch_size = prod(x.shape[:-2]) n_agents = x.shape[-2] x = x.view(-1, x.shape[-1]) if pos is not None: pos = pos.view(-1, pos.shape[-1]) if vel is not None: vel = vel.view(-1, vel.shape[-1]) b = torch.arange(batch_size, device=x.device) graphs = torch_geometric.data.Batch() graphs.ptr = torch.arange(0, (batch_size + 1) * n_agents, n_agents) graphs.batch = torch.repeat_interleave(b, n_agents) graphs.x = x graphs.pos = pos graphs.vel = vel graphs.edge_attr = None if edge_index is not None: n_edges = edge_index.shape[1] # Tensor of shape [batch_size * n_edges] # in which edges corresponding to the same graph have the same index. batch = torch.repeat_interleave(b, n_edges) # Edge index for the batched graphs of shape [2, n_edges * batch_size] # we sum to each batch an offset of batch_num * n_agents to make sure that # the adjacency matrices remain independent batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents graphs.edge_index = batch_edge_index else: if pos is None: raise RuntimeError("from_pos topology needs positions as input") graphs.edge_index = torch_geometric.nn.pool.radius_graph( graphs.pos, batch=graphs.batch, r=edge_radius, loop=self_loops ) graphs = graphs.to(x.device) if pos is not None: graphs = torch_geometric.transforms.Cartesian(norm=False)(graphs) graphs = torch_geometric.transforms.Distance(norm=False)(graphs) if vel is not None: graphs = _RelVel()(graphs) return graphs
[docs] @dataclass class GnnConfig(ModelConfig): """Dataclass config for a :class:`~benchmarl.models.Gnn`.""" topology: str = MISSING self_loops: bool = MISSING gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING gnn_kwargs: Optional[dict] = None position_key: Optional[str] = None pos_features: Optional[int] = 0 velocity_key: Optional[str] = None vel_features: Optional[int] = 0 exclude_pos_from_node_features: Optional[bool] = None edge_radius: Optional[float] = None
[docs] @staticmethod def associated_class(): return Gnn