benchmarl.models.Gnn

class Gnn(*args, **kwargs)

Bases: Model

A GNN model.

GNN models can be used as “decentralized” actors or critics.
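Here is a minimal sketch in plain PyTorch of what "decentralized" means. It is not BenchMARL's implementation, which delegates to a torch_geometric convolution. Each agent updates its features using only its own features and those of its neighbours in the agent graph, so the same module can produce per-agent outputs for an actor or a critic. The class name `MeanAggregationLayer` and the dense adjacency-matrix input are illustrative assumptions.

```python
import torch
from torch import nn


class MeanAggregationLayer(nn.Module):
    """Hypothetical one-layer message-passing block over a graph of agents.

    Each agent's output combines a projection of its own features with a
    projection of the mean of its neighbours' features.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.self_proj = nn.Linear(in_features, out_features)
        self.neigh_proj = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_agents, in_features)
        # adjacency: (n_agents, n_agents), adjacency[i, j] = 1 if agent i hears agent j
        degree = adjacency.sum(dim=-1, keepdim=True).clamp(min=1)
        # (n_agents, n_agents) @ (batch, n_agents, f) broadcasts over the batch
        neigh_mean = (adjacency @ x) / degree
        return torch.relu(self.self_proj(x) + self.neigh_proj(neigh_mean))


n_agents = 3
x = torch.randn(4, n_agents, 6)
# Fully connected graph without self loops
full = torch.ones(n_agents, n_agents) - torch.eye(n_agents)
layer = MeanAggregationLayer(6, 8)
out = layer(x, full)
print(out.shape)  # torch.Size([4, 3, 8])
```

With an all-zeros adjacency matrix, each agent's output depends only on its own observation. With a full graph, it also depends on every other agent.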

Parameters:
  • topology (str) – Topology of the graph adjacency matrix. Options: “full” (every agent is connected to every other agent), “empty” (no edges between distinct agents).

  • self_loops (bool) – Whether the resulting adjacency matrix will have self loops.

  • gnn_class (Type[torch_geometric.nn.MessagePassing]) – The GNN convolution class to use (for example, torch_geometric.nn.conv.GraphConv).

  • gnn_kwargs (dict, optional) – Keyword arguments passed to the GNN convolution class when it is instantiated.
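The `topology` and `self_loops` options together determine the agent graph. The helper below is a hypothetical sketch, not BenchMARL's code. It assumes that “full” connects every agent to every other agent and that “empty” has no edges between distinct agents. It builds the edge list in the `(2, num_edges)` `edge_index` layout that torch_geometric convolutions take as input.

```python
import torch


def build_edge_index(topology: str, self_loops: bool, n_agents: int) -> torch.Tensor:
    """Hypothetical helper: return a (2, num_edges) edge_index for the agent graph."""
    if topology == "full":
        adjacency = torch.ones(n_agents, n_agents, dtype=torch.bool)
    elif topology == "empty":
        adjacency = torch.zeros(n_agents, n_agents, dtype=torch.bool)
    else:
        raise ValueError(f"Unknown topology {topology!r}; expected 'full' or 'empty'")

    # The diagonal holds the self loops: switch it on or off as requested.
    eye = torch.eye(n_agents, dtype=torch.bool)
    adjacency = adjacency | eye if self_loops else adjacency & ~eye

    # nonzero() gives (num_edges, 2) index pairs; transpose to (2, num_edges),
    # with row 0 holding source nodes and row 1 holding target nodes.
    return adjacency.nonzero().t()


print(build_edge_index("full", False, 3))
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 2, 0, 2, 0, 1]])
```

Under these assumptions, “empty” without self loops yields a graph with no edges at all. In that case, each agent's output can only come from its own features.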

Examples

import torch_geometric
from torch import nn
from benchmarl.algorithms import IppoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import SequenceModelConfig, GnnConfig, MlpConfig
experiment = Experiment(
    algorithm_config=IppoConfig.get_from_yaml(),
    model_config=GnnConfig(
        topology="full",
        self_loops=False,
        gnn_class=torch_geometric.nn.conv.GATv2Conv,
        gnn_kwargs={},
    ),
    critic_model_config=SequenceModelConfig(
        model_configs=[
            MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear),
            GnnConfig(
                topology="full",
                self_loops=False,
                gnn_class=torch_geometric.nn.conv.GraphConv,
            ),
            MlpConfig(num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear),
        ],
        intermediate_sizes=[5, 3],
    ),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
    task=VmasTask.NAVIGATION.get_from_yaml(),
)
experiment.run()
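In the critic above, `SequenceModelConfig` chains three models. `intermediate_sizes=[5, 3]` appears to set the output sizes of the first two: the MLP outputs 5 features and the GNN outputs 3, and the last MLP produces the critic output. The plain-PyTorch sketch below mirrors only that shape bookkeeping under this assumption. An `nn.Linear` stands in for the GNN layer, and the observation and output sizes are made up for illustration.

```python
import torch
from torch import nn

obs_dim, out_dim = 10, 1          # illustrative sizes, not taken from the task
intermediate_sizes = [5, 3]       # same values as in the example above

critic_stack = nn.Sequential(
    # Like MlpConfig(num_cells=[8]): one hidden layer of 8 units, output size 5
    nn.Linear(obs_dim, 8), nn.Tanh(), nn.Linear(8, intermediate_sizes[0]),
    # Stand-in for the GraphConv GNN: maps 5 features to 3 per agent (no message passing here)
    nn.Linear(intermediate_sizes[0], intermediate_sizes[1]),
    # Like MlpConfig(num_cells=[6]): one hidden layer of 6 units, final output
    nn.Linear(intermediate_sizes[1], 6), nn.Tanh(), nn.Linear(6, out_dim),
)

x = torch.randn(32, 4, obs_dim)   # (batch, n_agents, obs_dim)
print(critic_stack(x).shape)      # torch.Size([32, 4, 1])
```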

_forward(tensordict: TensorDictBase) → TensorDictBase

Method to implement for the forward pass of the model. It should read self.in_key, process it, and write the result to self.out_key.

Parameters:

tensordict (TensorDictBase) – the input td

Returns: the input td, with self.out_key written
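A toy model can illustrate the `_forward` contract: read `self.in_key`, process it, write `self.out_key`, and return the same td. This sketch uses a plain `dict` as a stand-in for `TensorDictBase` and a hypothetical `ToyModel` class. It does not subclass BenchMARL's `Model`; only the read, process, write pattern is meant to carry over.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in showing the read -> process -> write pattern of _forward."""

    def __init__(self, in_key: str, out_key: str, in_features: int, out_features: int):
        super().__init__()
        self.in_key = in_key
        self.out_key = out_key
        self.net = nn.Linear(in_features, out_features)

    def _forward(self, tensordict: dict) -> dict:
        x = tensordict[self.in_key]             # read the input entry
        tensordict[self.out_key] = self.net(x)  # process it and write the output entry
        return tensordict                       # return the same, now updated, container


model = ToyModel(in_key="observation", out_key="logits", in_features=6, out_features=2)
td = {"observation": torch.randn(4, 3, 6)}
result = model._forward(td)
print(result["logits"].shape)  # torch.Size([4, 3, 2])
```

A TensorDict supports the same `td[key]` read and `td[key] = value` write syntax, which is why a dict works as a stand-in here.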