benchmarl.models.Gnn
- class Gnn(*args, **kwargs)
Bases: Model
A GNN model.
GNN models can be used as “decentralized” actors or critics.
- Parameters:
topology (str) – Topology of the graph adjacency matrix. Options: “full” (every agent is connected to every other agent), “empty” (no edges between agents), “from_pos” (builds the topology dynamically based on position_key and edge_radius).
self_loops (bool) – Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]) – the GNN convolution class to use (for example, torch_geometric.nn.conv.GraphConv).
gnn_kwargs (dict, optional) – keyword arguments passed to the gnn_class constructor.
position_key (str, optional) – if provided, it must match a leaf key in the tensordict coming from the env (in the observation_spec) representing the agent position. To do this, your environment needs to have dictionary observations, and one of the keys needs to be position_key. This key will be processed as a node feature (unless exclude_pos_from_node_features=True) and will be used to construct edge features: relative positions (pos_node_1 - pos_node_2) and a one-dimensional distance for all neighbours in the graph. If you want to use this feature in a SequenceModel, the GNN needs to be first in the sequence.
pos_features (int, optional) – Needed when position_key is specified. It has to match the last dimension of the tensor under position_key.
exclude_pos_from_node_features (bool, optional) – If position_key is provided, whether to use it only to compute edge features (True) or to also include it in the node features (False).
velocity_key (str, optional) – if provided, it must match a leaf key in the tensordict coming from the env (in the observation_spec) representing the agent velocity. To do this, your environment needs to have dictionary observations, and one of the keys needs to be velocity_key. This key will be processed as a node feature and will be used to construct edge features: relative velocities (vel_node_1 - vel_node_2) for all neighbours in the graph. If you want to use this feature in a SequenceModel, the GNN needs to be first in the sequence.
vel_features (int, optional) – Needed when velocity_key is specified. It has to match the last dimension of the tensor under velocity_key.
edge_radius (float, optional) – If topology is “from_pos”, the radius used to build the agent graph: agents within this distance of each other are neighbours.
Examples
```python
import torch_geometric
from torch import nn

from benchmarl.algorithms import IppoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig

experiment = Experiment(
    algorithm_config=IppoConfig.get_from_yaml(),
    # Actor: a single GATv2 layer over a fully connected agent graph
    model_config=GnnConfig(
        topology="full",
        self_loops=False,
        gnn_class=torch_geometric.nn.conv.GATv2Conv,
        gnn_kwargs={},
    ),
    # Critic: MLP -> GNN -> MLP; intermediate_sizes are the output sizes of the first two models
    critic_model_config=SequenceModelConfig(
        model_configs=[
            MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
            GnnConfig(
                topology="full",
                self_loops=False,
                gnn_class=torch_geometric.nn.conv.GraphConv,
            ),
            MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
        ],
        intermediate_sizes=[5, 3],
    ),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
    task=VmasTask.NAVIGATION.get_from_yaml(),
)
experiment.run()
```
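In the critic of the example, intermediate_sizes=[5, 3] sets the output size of the first MLP (5) and of the GNN (3), and the last MLP produces the critic output. The sketch below traces those shapes with plain torch.nn layers instead of BenchMARL's models. ToyGraphConv is a hypothetical stand-in that follows the update rule of torch_geometric's GraphConv with sum aggregation, and the single output value per agent is an assumption about the IPPO critic.

```python
import torch
from torch import nn


class ToyGraphConv(nn.Module):
    """Hypothetical GraphConv-style layer on a fully connected graph without self loops:
    x_i' = W_rel * sum_{j != i} x_j + W_root x_i."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.lin_rel = nn.Linear(in_features, out_features)
        self.lin_root = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, n_agents, in_features]
        n_agents = x.shape[-2]
        # Dense adjacency of the "full" topology with self_loops=False.
        adj = torch.ones(n_agents, n_agents, dtype=x.dtype, device=x.device)
        adj.fill_diagonal_(0)
        # Sum of neighbour features for each agent: [batch, n_agents, in_features].
        neigh_sum = adj @ x
        return self.lin_rel(neigh_sum) + self.lin_root(x)


torch.manual_seed(0)
batch, n_agents, obs_dim = 2, 3, 4
critic = nn.Sequential(
    # Like MlpConfig(num_cells=[8]) with output size intermediate_sizes[0] = 5
    nn.Linear(obs_dim, 8), nn.Tanh(), nn.Linear(8, 5),
    # Like GnnConfig(gnn_class=GraphConv) mapping 5 -> intermediate_sizes[1] = 3
    ToyGraphConv(5, 3),
    # Like MlpConfig(num_cells=[6]) mapping 3 -> 1 (assumed: one value per agent)
    nn.Linear(3, 6), nn.Tanh(), nn.Linear(6, 1),
)

obs = torch.randn(batch, n_agents, obs_dim)
values = critic(obs)
print(values.shape)  # torch.Size([2, 3, 1])
```

Because every agent sums over the same kind of neighbour set, permuting the agents in the input permutes the outputs in the same way.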
- _forward(tensordict: TensorDictBase) → TensorDictBase
Method to implement for the forward pass of the model. It should read self.in_keys, process them, and write self.out_key.
- Parameters:
tensordict (TensorDictBase) – the input td
Returns: the input td with self.out_key written
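The _forward contract (read self.in_keys, write self.out_key, return the same tensordict) can be illustrated with a small stand-in. This is a hypothetical sketch, not a BenchMARL Model subclass: ToyKeyedModel inherits from torch.nn.Module, uses a plain dict with tuple keys in place of TensorDictBase, and its processing step is an arbitrary per-agent linear layer.

```python
import torch
from torch import nn


class ToyKeyedModel(nn.Module):
    """Hypothetical model following the read-in_keys / write-out_key contract."""

    def __init__(self, in_keys, out_key, in_features: int, out_features: int):
        super().__init__()
        self.in_keys = list(in_keys)
        self.out_key = out_key
        self.linear = nn.Linear(in_features, out_features)

    def _forward(self, tensordict: dict) -> dict:
        # Read every input key and concatenate along the feature dimension.
        x = torch.cat([tensordict[key] for key in self.in_keys], dim=-1)
        # Process (here: one linear layer applied to each agent's features).
        out = self.linear(x)
        # Write the result under out_key and return the same container.
        tensordict[self.out_key] = out
        return tensordict


td = {
    ("agents", "observation"): torch.randn(2, 3, 4),  # [batch, n_agents, obs features]
    ("agents", "position"): torch.randn(2, 3, 2),  # [batch, n_agents, pos features]
}
model = ToyKeyedModel(
    in_keys=[("agents", "observation"), ("agents", "position")],
    out_key=("agents", "out"),
    in_features=6,
    out_features=5,
)
result = model._forward(td)
print(result[("agents", "out")].shape)  # torch.Size([2, 3, 5])
```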