benchmarl.models.Deepsets

class Deepsets(*args, **kwargs)

Bases: Model

Deepsets model from the Deep Sets paper (Zaheer et al., 2017).

The BenchMARL Deepsets model accepts multiple inputs of two types:

  • sets \(s\) : Tensors of shape (*batch,S,F)

  • arrays \(x\) : Tensors of shape (*batch,F)

The Deepsets model checks that all set inputs have the same shape (excluding the last dimension) and concatenates them along the last (feature) dimension before processing them.

It likewise checks that all array inputs have the same shape (excluding the last dimension) and concatenates them along the last dimension.

It then computes the output according to the following function:
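The shape check and concatenation described above can be sketched in plain PyTorch as follows. This is an illustration, not BenchMARL's code: the helper name cat_inputs is hypothetical, and the error type is an assumption.

```python
import torch


def cat_inputs(tensors):
    """Hypothetical helper: check that all tensors share every dimension
    except the last, then concatenate them along the last (feature) dim."""
    leading = tensors[0].shape[:-1]
    for t in tensors[1:]:
        if t.shape[:-1] != leading:
            raise ValueError(
                f"shape mismatch: {tuple(t.shape[:-1])} vs {tuple(leading)}"
            )
    return torch.cat(tensors, dim=-1)


# Two set inputs with batch=(4,), S=5 elements, and 3 + 2 features.
set_a = torch.randn(4, 5, 3)
set_b = torch.randn(4, 5, 2)
sets = cat_inputs([set_a, set_b])  # shape (4, 5, 5)

# Two array inputs with batch=(4,) and 6 + 1 features.
x = cat_inputs([torch.randn(4, 6), torch.randn(4, 1)])  # shape (4, 7)
```

Set inputs must agree on the number of elements S, while their feature sizes may differ.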

\[\rho \left (x, \bigoplus_{s\in S}\phi(s) \right ),\]

where \(\rho\) and \(\phi\) are MLPs configurable in the model setup, and \(\bigoplus\) is the aggregation over the set elements selected by the aggr parameter.
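A minimal sketch of the function above follows. It is not BenchMARL's implementation. The class name DeepSetsSketch, the fixed two-layer MLPs with Tanh, and the supported aggregation strings ("sum", "mean", "max") are all assumptions made for illustration. Because \(\phi\) is applied to each element independently and the aggregation is permutation invariant, shuffling the set elements leaves the output unchanged.

```python
import torch
from torch import nn


def aggregate(h: torch.Tensor, aggr: str, dim: int = -2) -> torch.Tensor:
    # Permutation-invariant reduction over the set dimension S (assumed options).
    if aggr == "sum":
        return h.sum(dim=dim)
    if aggr == "mean":
        return h.mean(dim=dim)
    if aggr == "max":
        return h.max(dim=dim).values
    raise ValueError(f"unknown aggregation: {aggr}")


class DeepSetsSketch(nn.Module):
    """Hypothetical module computing rho(x, aggr_s phi(s))."""

    def __init__(self, set_features, x_features, hidden, out_features, aggr="sum"):
        super().__init__()
        self.aggr = aggr
        # phi: applied independently to every set element s.
        self.phi = nn.Sequential(
            nn.Linear(set_features, hidden), nn.Tanh(), nn.Linear(hidden, hidden)
        )
        # rho: applied to the array input x concatenated with the pooled phi(s).
        self.rho = nn.Sequential(
            nn.Linear(x_features + hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, s, x):
        # s: (*batch, S, F_s), x: (*batch, F_x)
        pooled = aggregate(self.phi(s), self.aggr)  # (*batch, hidden)
        return self.rho(torch.cat([x, pooled], dim=-1))  # (*batch, out_features)


torch.manual_seed(0)
model = DeepSetsSketch(set_features=3, x_features=4, hidden=8, out_features=2)
s = torch.randn(10, 5, 3)
x = torch.randn(10, 4)
out = model(s, x)

# Shuffling the set elements gives the same output, up to float rounding.
perm = torch.randperm(5)
out_perm = model(s[:, perm], x)
```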

The model is useful in various contexts, for example:

  • When used as a policy (self.centralized==False, self.input_has_agent_dim==True), it can process observations with shape (*batch,n_agents,S,F), reducing them to (*batch,n_agents,F).

  • When used as a centralized critic with a global state as input (self.centralized==True, self.input_has_agent_dim==False), it can process the global state with shape (*batch,S,F), reducing it to (*batch,F).

  • When used as a centralized critic with local agent observations as input (self.centralized==True, self.input_has_agent_dim==True), it can process normal agent observations with shape (*batch,n_agents,F), reducing them to (*batch,F). Note: if the agents also have set observations (*batch,n_agents,S,F), it applies two Deepsets networks: the first removes the set dimension from the agents' inputs, giving (*batch,n_agents,F), and the second removes the agent dimension, giving (*batch,F). Both networks share the same configuration.
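The two-stage reduction in the last case can be illustrated by tracking shapes through two small Deep Sets networks. This sketch is an assumption-laden illustration, not BenchMARL's code: the helper make_deepset, its sum pooling, and all sizes are hypothetical. Both networks are built by the same helper with the same hidden width, loosely mirroring "same configuration" (no weight sharing is implied).

```python
import torch
from torch import nn


def make_deepset(in_features, out_features, hidden=16):
    # Hypothetical helper: phi maps each element, sum pooling over dim -2
    # removes the element dimension, and rho maps the pooled vector.
    phi = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
    rho = nn.Linear(hidden, out_features)
    return lambda t: rho(phi(t).sum(dim=-2))


batch, n_agents, S, F = 32, 3, 6, 4
obs = torch.randn(batch, n_agents, S, F)

# Stage 1: remove the set dimension of every agent's observation.
set_net = make_deepset(F, 8)
per_agent = set_net(obs)  # (batch, n_agents, 8)

# Stage 2: treat the agents themselves as a set and remove the agent dimension.
agent_net = make_deepset(8, 1)
value = agent_net(per_agent)  # (batch, 1)
```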

Parameters:
  • aggr (str) – The aggregation strategy to use in the Deepsets model.

  • local_nn_num_cells (Sequence[int]) – number of cells of every layer in between the input and output in the \(\phi\) MLP.

  • local_nn_activation_class (Type[nn.Module]) – activation class to be used in the \(\phi\) MLP.

  • out_features_local_nn (int) – output features of the \(\phi\) MLP.

  • global_nn_num_cells (Sequence[int]) – number of cells of every layer in between the input and output in the \(\rho\) MLP.

  • global_nn_activation_class (Type[nn.Module]) – activation class to be used in the \(\rho\) MLP.
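One plausible reading of how these parameters shape the two MLPs is sketched below with plain torch.nn layers. The names resemble the num_cells and activation_class arguments of TorchRL's MLP, but the helper build_mlp, the example values, and the input/output sizes are hypothetical. The key point is that \(\rho\) receives the array features concatenated with the out_features_local_nn outputs of the aggregated \(\phi\). Its output size is assumed to be set by the model's output spec, which is why there is no separate out_features parameter for \(\rho\).

```python
from typing import Sequence, Type

import torch
from torch import nn


def build_mlp(in_features: int, num_cells: Sequence[int], out_features: int,
              activation_class: Type[nn.Module]) -> nn.Sequential:
    # Hypothetical helper: one Linear per hidden size in num_cells, each followed
    # by an activation, then a final Linear to out_features (no activation).
    layers = []
    sizes = [in_features, *num_cells]
    for a, b in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(a, b), activation_class()]
    layers.append(nn.Linear(sizes[-1], out_features))
    return nn.Sequential(*layers)


# Hypothetical values for the documented parameters.
local_nn_num_cells = [32, 32]
local_nn_activation_class = nn.Tanh
out_features_local_nn = 64
global_nn_num_cells = [128]
global_nn_activation_class = nn.ReLU

set_features, x_features, model_out_features = 5, 7, 2  # assumed sizes

phi = build_mlp(set_features, local_nn_num_cells, out_features_local_nn,
                local_nn_activation_class)
# rho sees the array input x concatenated with the aggregated phi output.
rho = build_mlp(x_features + out_features_local_nn, global_nn_num_cells,
                model_out_features, global_nn_activation_class)
```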

_forward(tensordict: TensorDictBase) → TensorDictBase

Method to implement for the forward pass of the model. It should read self.in_keys, process them, and write self.out_key.

Parameters:

tensordict (TensorDictBase) – the input td

Returns: the input td with self.out_key written
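The read, process, write contract of _forward can be sketched as below. This is a hypothetical stand-in, not BenchMARL's code: a plain dict replaces the TensorDictBase (string-key reads and writes look the same on a TensorDict), the class name ForwardSketch is invented, and a single linear layer replaces the Deepsets computation.

```python
import torch


class ForwardSketch:
    """Hypothetical model: dict in place of TensorDictBase, `net` in place
    of the Deepsets computation."""

    def __init__(self, in_keys, out_key, net):
        self.in_keys = in_keys
        self.out_key = out_key
        self.net = net

    def _forward(self, tensordict):
        # Read every input key and concatenate along the feature dimension.
        inputs = torch.cat([tensordict[k] for k in self.in_keys], dim=-1)
        # Write the result under out_key in the same container.
        tensordict[self.out_key] = self.net(inputs)
        return tensordict  # the input object, now also holding out_key


td = {"obs": torch.randn(8, 3), "extra": torch.randn(8, 1)}
model = ForwardSketch(["obs", "extra"], "logits", torch.nn.Linear(4, 2))
result = model._forward(td)
```

Returning the same object, rather than a new one, matches the documented return value: the input td with self.out_key written.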