benchmarl.models.LstmConfig
- class LstmConfig(hidden_size: int, n_layers: int, bias: bool, dropout: float, compile: bool, mlp_num_cells: Sequence[int], mlp_layer_class: Type[Module], mlp_activation_class: Type[Module], mlp_activation_kwargs: dict | None = None, mlp_norm_class: Type[Module] | None = None, mlp_norm_kwargs: dict | None = None)
Bases: ModelConfig
Dataclass config for an LSTM.
- get_model_state_spec(model_index: int = 0) → Composite
Get additional specs needed by the model as input.
This method is useful for adding recurrent states.
The returned value should be a Composite that maps each key to a spec with the desired trailing shape.
The batch and agent dimensions will automatically be added to the spec.
- Parameters:
model_index (int, optional) – the index of the model. Defaults to 0.
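The shape contract described for get_model_state_spec (the method declares only the trailing shape, and the batch and agent dimensions are prepended automatically) can be illustrated with a minimal pure-Python sketch. It does not use BenchMARL or TorchRL: it returns plain tuples instead of a Composite spec. The class LstmStateSketch, the helper add_batch_and_agent_dims, the key names ("hidden_h", "hidden_c"), the use of model_index as a key suffix, and the (*batch, n_agents, ...) ordering are all hypothetical. Only n_layers and hidden_size come from the documented signature, and the per-agent state layout of (n_layers, hidden_size) is an assumption modeled on the h/c states of torch.nn.LSTM.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

Shape = Tuple[int, ...]


@dataclass
class LstmStateSketch:
    """Hypothetical stand-in for the recurrent-state part of LstmConfig.

    Only n_layers and hidden_size mirror the documented fields.
    """

    n_layers: int
    hidden_size: int

    def model_state_shapes(self, model_index: int = 0) -> Dict[str, Shape]:
        """Return key -> trailing shape (no batch or agent dims).

        Assumption: the real method returns a Composite spec; here we
        return plain shape tuples, and the key names are made up.
        Assumption: model_index only serves to keep keys unique.
        """
        prefix = "hidden" if model_index == 0 else f"hidden_{model_index}"
        # An LSTM carries a hidden state h and a cell state c; we assume
        # each is stored per agent as (n_layers, hidden_size).
        trailing = (self.n_layers, self.hidden_size)
        return {f"{prefix}_h": trailing, f"{prefix}_c": trailing}


def add_batch_and_agent_dims(
    shapes: Dict[str, Shape], batch_size: Shape, n_agents: int
) -> Dict[str, Shape]:
    """Prepend batch and agent dimensions to every trailing shape.

    This mimics the documented behavior that these dimensions are added
    automatically; the (*batch, n_agents, *trailing) order is an assumption.
    """
    return {key: (*batch_size, n_agents, *shape) for key, shape in shapes.items()}


if __name__ == "__main__":
    sketch = LstmStateSketch(n_layers=2, hidden_size=64)
    trailing = sketch.model_state_shapes()
    full = add_batch_and_agent_dims(trailing, batch_size=(32,), n_agents=3)
    print(full)  # {'hidden_h': (32, 3, 2, 64), 'hidden_c': (32, 3, 2, 64)}
```

Keeping the method limited to trailing shapes means the same config works regardless of how many environments are batched or how many agents a group contains; those sizes are only known to the framework at runtime.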