benchmarl.models.Lstm

class Lstm(*args, **kwargs)

Bases: Model

A multi-layer Long Short-Term Memory (LSTM) RNN like the one from torch.

The BenchMARL LSTM accepts multiple inputs of type array: tensors of shape (*batch, F), where F is the number of features.

These arrays are concatenated along the F dimension, and the LSTM processes the result into features of size hidden_size.
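The concatenate-then-recur idea described above can be sketched with plain torch. This is a minimal illustration, not the BenchMARL implementation: the tensor names, the sizes, and the explicit time dimension passed to torch.nn.LSTM are all assumptions made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not BenchMARL defaults).
batch, seq_len, hidden_size = 4, 5, 8

# Two hypothetical array inputs that share leading dims and differ only in F.
obs = torch.randn(batch, seq_len, 6)    # F = 6
extra = torch.randn(batch, seq_len, 3)  # F = 3

# Concatenate along the last (feature) dimension: F_total = 6 + 3 = 9.
features = torch.cat([obs, extra], dim=-1)

# A single-layer LSTM maps the F_total features to hidden_size features.
lstm = nn.LSTM(input_size=features.shape[-1], hidden_size=hidden_size, batch_first=True)
output, (h_n, c_n) = lstm(features)

print(features.shape)  # torch.Size([4, 5, 9])
print(output.shape)    # torch.Size([4, 5, 8])
```

The last dimension of the output equals hidden_size regardless of how many inputs were concatenated, which is why inputs only need to agree on their leading dimensions.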

Parameters:
  • hidden_size (int) – The number of features in the hidden state.

  • num_layers (int) – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in the outputs of the first LSTM and computing the final results. Default: 1

  • bias (bool) – If False, then the LSTM layers do not use bias. Default: True

  • dropout (float) – If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0

  • compile (bool) – If True, compiles the underlying LSTM model. Default: False
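Since the class is described as being like the torch LSTM, the effect of hidden_size, num_layers, bias, and dropout can be inspected on torch.nn.LSTM directly. The following is a sketch under that assumption; the input size and batch shape are made up for illustration, and the compile flag is left out.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): 9 input features, 8 hidden features.
lstm = nn.LSTM(
    input_size=9,
    hidden_size=8,
    num_layers=2,   # two stacked LSTM layers
    bias=True,
    dropout=0.1,    # applied to the outputs of layer 0 only (not the last layer)
    batch_first=True,
)

x = torch.randn(4, 5, 9)  # (batch, time, features)
output, (h_n, c_n) = lstm(x)

# One final hidden state per layer: (num_layers, batch, hidden_size).
print(h_n.shape)  # torch.Size([2, 4, 8])

# The second layer consumes the first layer's hidden_size-dimensional outputs.
print(lstm.weight_ih_l0.shape)  # torch.Size([32, 9])
print(lstm.weight_ih_l1.shape)  # torch.Size([32, 8])

# With bias=False, no bias parameters are registered at all.
no_bias = nn.LSTM(input_size=9, hidden_size=8, num_layers=2, bias=False)
bias_names = [name for name, _ in no_bias.named_parameters() if "bias" in name]
print(bias_names)  # []
```

Note that in torch.nn.LSTM a non-zero dropout only has an effect when num_layers is greater than 1, because dropout is never applied after the last layer.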

_forward(tensordict: TensorDictBase) → TensorDictBase

Method to implement for the forward pass of the model. It should read the entries at self.in_keys, process them, and write the result to self.out_key.

Parameters:
  • tensordict (TensorDictBase) – the input tensordict

Returns: the input tensordict, with self.out_key written
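The read-process-write contract of _forward can be sketched as follows. ToyModel is a hypothetical stand-in, not a subclass of the BenchMARL Model, and a plain dict is used in place of a TensorDictBase so the example runs with torch alone; the key names and sizes are also assumptions.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in that mimics the in_keys / out_key contract."""

    def __init__(self, in_keys, out_key, in_features, out_features):
        super().__init__()
        self.in_keys = in_keys
        self.out_key = out_key
        self.linear = nn.Linear(in_features, out_features)

    def _forward(self, tensordict):
        # Read every input key and concatenate along the feature dimension.
        inputs = torch.cat([tensordict[key] for key in self.in_keys], dim=-1)
        # Write the result under out_key and return the same container.
        tensordict[self.out_key] = self.linear(inputs)
        return tensordict


# A plain dict stands in for the tensordict (assumption for illustration).
td = {"observation": torch.randn(4, 6), "extra": torch.randn(4, 3)}
model = ToyModel(["observation", "extra"], "logits", in_features=9, out_features=5)
out = model._forward(td)

print(sorted(out.keys()))   # ['extra', 'logits', 'observation']
print(out["logits"].shape)  # torch.Size([4, 5])
```

The input container is returned with the new entry added rather than replaced by a fresh one, matching the description of the return value above.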