benchmarl.models.Gru

class Gru(*args, **kwargs)

Bases: Model

A multi-layer Gated Recurrent Unit (GRU) RNN, like the one from torch.
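For reference, the per-layer recurrence documented for torch.nn.GRU is reproduced below, assuming the BenchMARL model follows it unchanged. Here x_t is the input at step t, h_t is the hidden state (of size hidden_size), σ is the sigmoid function, and ⊙ is the element-wise product. With bias=False the b terms are absent.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```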

The BenchMARL GRU accepts multiple inputs of type array: tensors of shape (*batch, F), where F is the number of features. These arrays are concatenated along the F dimension, and the result is processed by the GRU into features of size hidden_size.
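The shape handling described above can be sketched with plain PyTorch. This is an illustration only, not the BenchMARL implementation: torch.nn.GRU stands in for the model, the input names obs and extra are hypothetical, and the leading dimensions are assumed to be (batch, time) so that batch_first=True applies.

```python
import torch

# Hypothetical inputs sharing leading dims (batch=4, time=10)
# but with different feature sizes F1 = 3 and F2 = 5.
obs = torch.randn(4, 10, 3)
extra = torch.randn(4, 10, 5)

# Concatenate along the last (feature) dimension: F = 3 + 5 = 8.
x = torch.cat([obs, extra], dim=-1)

hidden_size = 16
# Plain torch.nn.GRU used as a stand-in for the BenchMARL model (assumption).
gru = torch.nn.GRU(input_size=x.shape[-1], hidden_size=hidden_size, batch_first=True)

# output holds one hidden_size feature vector per time step;
# h_n holds the final hidden state of each layer.
output, h_n = gru(x)
print(x.shape)       # torch.Size([4, 10, 8])
print(output.shape)  # torch.Size([4, 10, 16])
```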

Parameters:
  • hidden_size (int) – The number of features in the hidden state.

  • num_layers (int) – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1

  • bias (bool) – If False, then the GRU layers do not use bias. Default: True

  • dropout (float) – If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0

  • compile (bool) – If True, compiles the underlying GRU model. Default: False
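The sketch below shows what these parameters control, assuming they map directly onto the torch.nn.GRU arguments of the same name. The input size 8 and the other concrete values are arbitrary choices for illustration. The exact call BenchMARL makes when compile=True is not stated here; presumably it is something along the lines of torch.compile, so that step is only shown as a comment.

```python
import torch

# Assumption: the documented options correspond to these torch.nn.GRU arguments.
gru = torch.nn.GRU(
    input_size=8,
    hidden_size=32,   # features in the hidden state
    num_layers=2,     # the second GRU consumes the outputs of the first
    bias=False,       # no bias vectors in any layer
    dropout=0.1,      # applied to the first layer's outputs, not the last layer's
    batch_first=True,
)

x = torch.randn(4, 10, 8)
output, h_n = gru(x)

# The last layer produces hidden_size features per step.
print(output.shape)  # torch.Size([4, 10, 32])
# One final hidden state per stacked layer.
print(h_n.shape)     # torch.Size([2, 4, 32])
# With bias=False, only weight matrices are registered as parameters.
print(sorted(name for name, _ in gru.named_parameters()))

# compile=True might correspond to wrapping the module (not executed here):
# compiled_gru = torch.compile(gru)
```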

_forward(tensordict: TensorDictBase) → TensorDictBase

Method to implement for the forward pass of the model. It should read the entries at self.in_keys, process them, and write the result to self.out_key.

Parameters:

tensordict (TensorDictBase) – the input td

Returns: the input td, with self.out_key written to it
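A minimal sketch of the read/process/write contract of _forward follows. The class GruForwardSketch is hypothetical and is not the BenchMARL model: it omits any hidden-state handling between calls, and a plain dict stands in for TensorDictBase, which supports the same string-key reads and writes.

```python
from collections.abc import MutableMapping

import torch


class GruForwardSketch(torch.nn.Module):
    """Hypothetical stand-in illustrating the _forward contract.

    Reads every key in in_keys, concatenates the tensors along the
    feature dimension, runs a GRU, and writes the output to out_key.
    """

    def __init__(self, in_keys, out_key, input_size, hidden_size):
        super().__init__()
        self.in_keys = list(in_keys)
        self.out_key = out_key
        self.gru = torch.nn.GRU(input_size, hidden_size, batch_first=True)

    def _forward(self, tensordict: MutableMapping) -> MutableMapping:
        # Gather inputs of shape (*batch, F_i) and join them into (*batch, F).
        x = torch.cat([tensordict[key] for key in self.in_keys], dim=-1)
        output, _ = self.gru(x)
        # Write the result into the same container and return it.
        tensordict[self.out_key] = output
        return tensordict


td = {"obs": torch.randn(4, 10, 3), "extra": torch.randn(4, 10, 5)}
model = GruForwardSketch(["obs", "extra"], "out", input_size=8, hidden_size=16)
result = model._forward(td)
print(result is td, result["out"].shape)  # True torch.Size([4, 10, 16])
```

Returning the same container, rather than a new one, matches the documented return value: the input td with the output key added.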