benchmarl.models.Lstm
- class Lstm(*args, **kwargs)
Bases: Model
A multi-layer Long Short-Term Memory (LSTM) RNN, like the one from torch.
The BenchMARL LSTM accepts multiple inputs of type array: tensors of shape (*batch, F), where F is the number of features. These arrays are concatenated along the F dimension, and the result is processed by the LSTM into features of size hidden_size.
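The concatenation rule above can be illustrated with a framework-free sketch. It assumes nested Python lists as stand-ins for tensors of shape (*batch, F); the function name concat_last_dim is hypothetical and not part of BenchMARL, which works on torch tensors.

```python
def concat_last_dim(*inputs: list) -> list:
    """Concatenate nested lists of shape (*batch, F_i) into (*batch, sum(F_i))."""
    if not inputs:
        raise ValueError("need at least one input")
    first = inputs[0]
    # Innermost level: these lists are feature vectors, so join them end to end.
    if not first or not isinstance(first[0], list):
        joined = []
        for features in inputs:
            joined.extend(features)
        return joined
    # Outer levels are batch dimensions and must have identical sizes.
    if any(len(x) != len(first) for x in inputs):
        raise ValueError("batch dimensions of all inputs must match")
    return [concat_last_dim(*parts) for parts in zip(*inputs)]


# Two inputs with batch size 2: F=2 and F=1 give a combined F of 3.
obs = [[1.0, 2.0], [3.0, 4.0]]
vel = [[0.5], [0.6]]
print(concat_last_dim(obs, vel))  # [[1.0, 2.0, 0.5], [3.0, 4.0, 0.6]]
```

Only the last dimension grows; every batch dimension must agree across inputs, which is why mismatched batch sizes raise an error here.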
- Parameters:
hidden_size (int) – The number of features in the hidden state.
num_layers (int) – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1
bias (bool) – If False, then the LSTM layers do not use bias. Default: True
dropout (float) – If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
compile (bool) – If True, compiles the underlying LSTM model. Default: False
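How num_layers, bias and dropout interact can be sketched in plain Python. This is an illustrative stacked LSTM, not BenchMARL's implementation: the classes LSTMLayer and StackedLSTM are hypothetical, a single combined bias vector stands in for PyTorch's separate input and hidden biases, and the gate order (input, forget, cell, output) and uniform init range 1/sqrt(hidden_size) follow torch.nn.LSTM.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class LSTMLayer:
    """One LSTM layer with four gates (i, f, g, o) per hidden unit."""

    def __init__(self, input_size: int, hidden_size: int, bias: bool, rng: random.Random):
        self.hidden_size = hidden_size
        k = 1.0 / math.sqrt(hidden_size)  # uniform init range, as in torch.nn.LSTM
        cols = input_size + hidden_size  # each gate sees [x_t, h_{t-1}]
        self.weight = [[rng.uniform(-k, k) for _ in range(cols)] for _ in range(4 * hidden_size)]
        # bias=False: gate pre-activations get no additive term.
        self.bias = [rng.uniform(-k, k) for _ in range(4 * hidden_size)] if bias else [0.0] * (4 * hidden_size)

    def step(self, x: Vector, h: Vector, c: Vector) -> Tuple[Vector, Vector]:
        z = x + h  # list concatenation of input and previous hidden state
        pre = [sum(w * v for w, v in zip(row, z)) + b for row, b in zip(self.weight, self.bias)]
        H = self.hidden_size
        i = [sigmoid(p) for p in pre[0:H]]
        f = [sigmoid(p) for p in pre[H:2 * H]]
        g = [math.tanh(p) for p in pre[2 * H:3 * H]]
        o = [sigmoid(p) for p in pre[3 * H:4 * H]]
        c_new = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c, i, g)]
        h_new = [oj * math.tanh(cj) for oj, cj in zip(o, c_new)]
        return h_new, c_new


class StackedLSTM:
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
                 bias: bool = True, dropout: float = 0.0, seed: int = 0):
        if num_layers < 1 or not 0.0 <= dropout <= 1.0:
            raise ValueError("need num_layers >= 1 and 0 <= dropout <= 1")
        self.rng = random.Random(seed)
        # Layer 0 reads the input features; every later layer reads hidden_size features.
        self.layers = [LSTMLayer(input_size if n == 0 else hidden_size, hidden_size, bias, self.rng)
                       for n in range(num_layers)]
        self.dropout = dropout
        self.training = True

    def _drop(self, v: Vector) -> Vector:
        # Inverted dropout: zero each value with probability p, rescale survivors by 1/(1-p).
        if not self.training or self.dropout == 0.0:
            return v
        keep = 1.0 - self.dropout
        return [x / keep if self.rng.random() < keep else 0.0 for x in v]

    def forward(self, seq: List[Vector]) -> List[Vector]:
        H = self.layers[0].hidden_size
        h = [[0.0] * H for _ in self.layers]
        c = [[0.0] * H for _ in self.layers]
        outputs = []
        for x in seq:
            for n, layer in enumerate(self.layers):
                h[n], c[n] = layer.step(x, h[n], c[n])
                x = h[n]
                # Dropout on every layer's output except the last one.
                if n < len(self.layers) - 1:
                    x = self._drop(x)
            outputs.append(x)
        return outputs


model = StackedLSTM(input_size=3, hidden_size=4, num_layers=2, dropout=0.1)
out = model.forward([[1.0, 2.0, 0.5], [3.0, 4.0, 0.6]])
print(len(out), len(out[0]))  # 2 4
```

Because dropout is skipped after the last layer, a nonzero dropout has no effect when num_layers=1.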
- _forward(tensordict: TensorDictBase) → TensorDictBase
Method to implement for the forward pass of the model. It should read self.in_keys, process them and write self.out_key.
- Parameters:
tensordict (TensorDictBase) – the input td
Returns: the input td with self.out_key written
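The read-process-write contract of _forward can be sketched with a plain dict standing in for a TensorDictBase. This is a hypothetical illustration, not BenchMARL code: forward_sketch, the key names and the net callable are assumptions, and inputs are unbatched 1-D feature lists for brevity.

```python
from typing import Callable, Dict, List, Sequence


def forward_sketch(td: Dict[str, List[float]], in_keys: Sequence[str], out_key: str,
                   net: Callable[[List[float]], List[float]]) -> Dict[str, List[float]]:
    # Read every input key and concatenate along the feature dimension.
    features: List[float] = []
    for key in in_keys:
        features.extend(td[key])
    # Process the features and write the result under out_key.
    td[out_key] = net(features)
    # Return the same (mutated) mapping, mirroring "the input td with out_key written".
    return td


td = {"obs": [1.0, 2.0], "vel": [0.5]}
result = forward_sketch(td, ["obs", "vel"], "out", lambda x: [sum(x)])
print(result["out"])  # [3.5]
```

The input mapping is modified in place and returned, so callers can chain further modules on the same td.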