benchmarl.models.LstmConfig

class LstmConfig(hidden_size: int, n_layers: int, bias: bool, dropout: float, compile: bool, mlp_num_cells: Sequence[int], mlp_layer_class: Type[Module], mlp_activation_class: Type[Module], mlp_activation_kwargs: dict | None = None, mlp_norm_class: Type[Module] | None = None, mlp_norm_kwargs: dict | None = None)

Bases: ModelConfig

Dataclass config for an LSTM.

static associated_class()

The Model class associated with this config.

property is_rnn: bool

Whether the model is an RNN

get_model_state_spec(model_index: int = 0) → Composite

Get additional specs needed by the model as input.

This method is useful for adding recurrent states.

The returned value should map each key to a spec whose shape is the desired trailing shape.

The batch and agent dimensions will automatically be added to the spec.

Parameters:

model_index (int, optional) – the index of the model. Defaults to 0.
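
The shape convention described for get_model_state_spec can be illustrated with a minimal sketch in plain Python, using tuples in place of real specs. It does not call BenchMARL. The helper names, the key names, the trailing shape of (n_layers, hidden_size) for the hidden and cell states, and the ordering of batch dimensions before the agent dimension are all assumptions made for illustration.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def lstm_state_trailing_shapes(
    n_layers: int, hidden_size: int, model_index: int = 0
) -> Dict[str, Shape]:
    """Hypothetical stand-in for a state spec: maps key -> trailing shape.

    Key names are illustrative only, not BenchMARL's actual keys.
    An LSTM carries a hidden state and a cell state, assumed here to
    each have trailing shape (n_layers, hidden_size).
    """
    return {
        f"lstm_h_{model_index}": (n_layers, hidden_size),
        f"lstm_c_{model_index}": (n_layers, hidden_size),
    }


def add_leading_dims(
    shapes: Dict[str, Shape], batch_size: Shape, n_agents: int
) -> Dict[str, Shape]:
    """Prepend batch and agent dimensions to every trailing shape.

    Assumption: batch dimensions come first, followed by the agent dimension.
    """
    return {key: (*batch_size, n_agents, *shape) for key, shape in shapes.items()}


trailing = lstm_state_trailing_shapes(n_layers=2, hidden_size=64)
full = add_leading_dims(trailing, batch_size=(32,), n_agents=3)
print(full["lstm_h_0"])  # (32, 3, 2, 64)
```

The point of the sketch is that the method only needs to describe the per-agent, per-sample state; the leading dimensions are supplied afterwards.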