benchmarl.models.Gnn
- class Gnn(*args, **kwargs)
Bases: Model
A GNN model.
GNN models can be used as “decentralized” actors or critics.
- Parameters:
topology (str) – Topology of the graph adjacency matrix. Options: “full” (every agent is connected to every other agent) and “empty” (no edges between agents).
self_loops (bool) – Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]) – the GNN convolution class to use.
gnn_kwargs (dict, optional) – the dict of keyword arguments to pass to the GNN convolution class.
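The two topology options and the self_loops flag can be pictured as an edge list over agent indices. The plain-Python sketch below assumes agents are numbered 0 to n_agents - 1 and edges are stored as directed (source, target) pairs, in the spirit of a PyTorch Geometric edge index; the helper `build_edges` is hypothetical and is not part of BenchMARL.

```python
from typing import List, Tuple


def build_edges(topology: str, self_loops: bool, n_agents: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: list the directed (source, target) edges of the agent graph."""
    if topology == "full":
        # Every agent is connected to every other agent.
        edges = [(i, j) for i in range(n_agents) for j in range(n_agents) if i != j]
    elif topology == "empty":
        # No edges between distinct agents.
        edges = []
    else:
        raise ValueError(f"Unknown topology: {topology!r}")
    if self_loops:
        # Add one edge from each agent to itself.
        edges += [(i, i) for i in range(n_agents)]
    return edges


print(len(build_edges("full", self_loops=False, n_agents=3)))  # 6
print(build_edges("empty", self_loops=True, n_agents=3))  # [(0, 0), (1, 1), (2, 2)]
```

With topology “empty” and no self loops the edge list is empty, so during message passing no agent receives messages from any other agent.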
Examples
```python
import torch_geometric
from torch import nn

from benchmarl.algorithms import IppoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig

experiment = Experiment(
    algorithm_config=IppoConfig.get_from_yaml(),
    # Actor: a single GNN layer over a fully connected agent graph
    model_config=GnnConfig(
        topology="full",
        self_loops=False,
        gnn_class=torch_geometric.nn.conv.GATv2Conv,
        gnn_kwargs={},
    ),
    # Critic: MLP -> GNN -> MLP, chained by SequenceModelConfig
    critic_model_config=SequenceModelConfig(
        model_configs=[
            MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
            GnnConfig(
                topology="full",
                self_loops=False,
                gnn_class=torch_geometric.nn.conv.GraphConv,
            ),
            MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
        ],
        intermediate_sizes=[5, 3],
    ),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
    task=VmasTask.NAVIGATION.get_from_yaml(),
)
experiment.run()
```
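In the critic of this example, SequenceModelConfig chains three models, and intermediate_sizes appears to give the feature size passed between consecutive models (two sizes for three models). Assuming that reading, the plain-Python sketch below spells out the resulting per-model input and output sizes; the helper `chain_sizes` and the input size 18 and output size 1 are hypothetical, illustrative values only.

```python
from typing import List, Tuple


def chain_sizes(input_size: int, intermediate_sizes: List[int], output_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) for each model in a sequence."""
    sizes = [input_size] + list(intermediate_sizes) + [output_size]
    # Model k reads sizes[k] features and writes sizes[k + 1] features.
    return list(zip(sizes[:-1], sizes[1:]))


# MLP -> GNN -> MLP with intermediate_sizes=[5, 3]; 18 and 1 are made-up sizes.
print(chain_sizes(18, [5, 3], 1))  # [(18, 5), (5, 3), (3, 1)]
```

Under this reading, the GNN in the middle of the critic maps 5 features per agent to 3.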
- _forward(tensordict: TensorDictBase) → TensorDictBase
Method to implement for the forward pass of the model. It should read self.in_key, process it and write self.out_key.
- Parameters:
tensordict (TensorDictBase) – the input tensordict
Returns: the input tensordict, with self.out_key written to it
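The _forward contract (read self.in_key, process it, write self.out_key, return the same container) can be illustrated without torch. The sketch below is a simplified stand-in, not BenchMARL's implementation: it assumes a plain dict in place of a TensorDictBase, per-agent features as lists of floats, and mean aggregation over a fully connected graph without self loops. The class `ToyGnn` and the keys "obs" and "out" are hypothetical.

```python
from typing import Dict, List


class ToyGnn:
    """Hypothetical stand-in that mimics the read/process/write contract of _forward."""

    def __init__(self, in_key: str, out_key: str) -> None:
        self.in_key = in_key
        self.out_key = out_key

    def _forward(self, tensordict: Dict[str, List[List[float]]]) -> Dict[str, List[List[float]]]:
        x = tensordict[self.in_key]  # one feature vector per agent
        n_agents = len(x)
        out = []
        for i in range(n_agents):
            # Fully connected graph without self loops: agent i hears from every j != i.
            neighbours = [x[j] for j in range(n_agents) if j != i]
            if neighbours:
                # Mean-aggregate neighbour features, one dimension at a time.
                out.append([sum(vals) / len(neighbours) for vals in zip(*neighbours)])
            else:
                # Single agent: no messages, so output zeros of the same width.
                out.append([0.0] * len(x[i]))
        tensordict[self.out_key] = out  # write the result into the same container
        return tensordict


td = {"obs": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]}
td = ToyGnn("obs", "out")._forward(td)
print(td["out"])  # [[4.0, 5.0], [3.0, 4.0], [2.0, 3.0]]
```

A real layer built from gnn_class would apply learned weights rather than a plain mean; the sketch only shows that the input entry stays in place and the result is written under self.out_key on the same object that is returned.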