# Source code for benchmarl.models.lstm

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from __future__ import annotations

from dataclasses import dataclass, MISSING
from typing import Optional, Sequence, Type

import torch
import torch.nn.functional as F
from tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right, unravel_key_list
from torch import nn
from torchrl.data.tensor_specs import Composite, Unbounded

from torchrl.modules import LSTMCell, MLP, MultiAgentMLP

from benchmarl.models.common import Model, ModelConfig
from benchmarl.utils import DEVICE_TYPING


class LSTM(torch.nn.Module):
    """Stack of ``LSTMCell`` layers unrolled over time, with per-step state resets.

    Compared to :class:`torch.nn.LSTM`, :meth:`forward` additionally takes an
    ``is_init`` mask: at every time step where the mask is set, the hidden and
    cell states of every layer are replaced by zeros before that step is
    computed.

    Args:
        input_size: number of features fed to the first layer.
        hidden_size: size of the hidden and cell state of every layer.
        device: device on which the cells are created.
        n_layers: number of stacked cells.
        dropout: if non-zero, dropout probability applied (only in training
            mode) to the output of every layer except the last one.
        bias: whether the cells use bias terms.
        time_dim: dimension of ``input``/``is_init`` that indexes time steps.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        device: DEVICE_TYPING,
        n_layers: int,
        dropout: float,
        bias: bool,
        time_dim: int = -2,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        self.time_dim = time_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.bias = bias

        cells = []
        for layer_idx in range(n_layers):
            # Only the bottom layer reads the raw input; every layer above it
            # reads the hidden state produced by the layer below.
            layer_in = input_size if layer_idx == 0 else hidden_size
            cells.append(LSTMCell(layer_in, hidden_size, device=device, bias=bias))
        self.lstms = torch.nn.ModuleList(cells)

    def forward(self, input, is_init, h, c):
        """Unroll the stacked cells over the ``time_dim`` dimension.

        Args:
            input: input sequence whose ``time_dim`` dimension indexes time
                steps (e.g. ``(batch, seq, input_size)`` with the default
                ``time_dim=-2``).
            is_init: boolean reset mask, unbound along ``time_dim`` in lockstep
                with ``input``; each per-step slice must broadcast against the
                per-layer states (e.g. ``(batch, seq, 1)``).
            h: initial hidden states with layers stacked on dim ``-2``
                (e.g. ``(batch, n_layers, hidden_size)``).
            c: initial cell states, same layout as ``h``.

        Returns:
            ``(output, h_n, c_n)``: the top layer's output at every step,
            stacked along ``time_dim``, and the final hidden / cell states with
            layers stacked on dim ``-2``.
        """
        hidden = list(h.unbind(dim=-2))
        cell = list(c.unbind(dim=-2))
        top = self.n_layers - 1
        per_step = []

        for x_t, reset_t in zip(
            input.unbind(self.time_dim), is_init.unbind(self.time_dim)
        ):
            for k in range(self.n_layers):
                # Zero the recurrent state wherever the reset mask is set.
                state = (
                    torch.where(reset_t, 0, hidden[k]),
                    torch.where(reset_t, 0, cell[k]),
                )
                hidden[k], cell[k] = self.lstms[k](x_t, state)
                x_t = hidden[k]
                # Inter-layer dropout: never applied after the top layer.
                if self.dropout and k < top:
                    x_t = F.dropout(x_t, p=self.dropout, training=self.training)
            per_step.append(x_t)

        return (
            torch.stack(per_step, self.time_dim),
            torch.stack(hidden, dim=-2),
            torch.stack(cell, dim=-2),
        )


def get_net(input_size, hidden_size, n_layers, bias, device, dropout, compile):
    """Construct an :class:`LSTM`, optionally wrapped by :func:`torch.compile`.

    Args:
        input_size: number of features fed to the first layer.
        hidden_size: hidden / cell state size of every layer.
        n_layers: number of stacked layers.
        bias: whether the cells use bias terms.
        device: device on which the parameters are created.
        dropout: dropout probability between layers.
        compile: if truthy, the network is returned wrapped by
            ``torch.compile(..., mode="reduce-overhead")``.

    Returns:
        The :class:`LSTM` module, or its compiled wrapper.
    """
    net = LSTM(
        input_size,
        hidden_size,
        device=device,
        n_layers=n_layers,
        dropout=dropout,
        bias=bias,
    )
    if not compile:
        return net
    return torch.compile(net, mode="reduce-overhead")


class MultiAgentLSTM(torch.nn.Module):
    """Applies :class:`LSTM` networks to inputs that carry a multi-agent dimension.

    One network is built per agent (``share_params=False``), or a single
    network is shared by all agents (``share_params=True``). The parameters of
    these networks are collected into ``self.params`` and, at call time, swapped
    into a parameter-less template network (``self._empty_lstm``, built on the
    ``meta`` device). Non-shared parameters are evaluated for all agents in one
    :func:`torch.vmap` call.

    Args:
        input_size (int): number of input features per agent.
        hidden_size (int): number of features in the LSTM hidden state.
        n_agents (int): number of agents (size of the agent dimension).
        device (DEVICE_TYPING): device on which the parameters are created.
        centralised (bool): if ``True``, every network receives the inputs of
            all agents concatenated (``input_size * n_agents`` features).
        share_params (bool): if ``True``, one parameter set is used for all
            agents.
        n_layers (int): number of stacked LSTM layers.
        dropout (float): dropout probability, forwarded to :func:`get_net`.
        bias (bool): whether the LSTM cells use bias, forwarded to
            :func:`get_net`.
        compile (bool): forwarded to :func:`get_net` as ``compile``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_agents: int,
        device: DEVICE_TYPING,
        centralised: bool,
        share_params: bool,
        n_layers: int,
        dropout: float,
        bias: bool,
        compile: bool,
    ):
        super().__init__()
        self.input_size = input_size
        self.n_agents = n_agents
        self.hidden_size = hidden_size
        self.device = device
        self.centralised = centralised
        self.share_params = share_params
        self.n_layers = n_layers
        self.bias = bias
        self.dropout = dropout
        self.compile = compile

        if self.centralised:
            # Each network sees the features of all agents flattened together.
            input_size = input_size * self.n_agents

        # One network per agent, or a single one when parameters are shared.
        agent_networks = [
            get_net(
                input_size=input_size,
                hidden_size=self.hidden_size,
                n_layers=self.n_layers,
                bias=self.bias,
                device=self.device,
                dropout=self.dropout,
                compile=self.compile,
            )
            for _ in range(self.n_agents if not self.share_params else 1)
        ]
        self._make_params(agent_networks)

        # Template network with the same architecture. Its own weights are
        # never used: run_net swaps self.params into it for every call.
        with torch.device("meta"):
            self._empty_lstm = get_net(
                input_size=input_size,
                hidden_size=self.hidden_size,
                n_layers=self.n_layers,
                bias=self.bias,
                device="meta",
                dropout=self.dropout,
                compile=self.compile,
            )
            # Remove all parameters
            # (replaces them with plain meta tensors taken via `.data`)
            TensorDict.from_module(self._empty_lstm).data.to("meta").to_module(
                self._empty_lstm
            )

    def forward(
        self,
        input,
        is_init,
        h_0=None,
        c_0=None,
    ):
        """Run the agent network(s) on a batch.

        Two modes are distinguished by whether ``h_0`` is given:

        * ``h_0 is None`` ("training"): ``input`` is
          ``(batch, seq, n_agents, input_size)``, ``is_init`` is
          ``(batch, seq, 1)``, and the recurrent state starts from zeros.
        * ``h_0`` given ("collection"): a single step is processed. ``input``
          is ``(batch, n_agents, input_size)``, or ``(n_agents, input_size)``
          when the batch dimension is missing. ``is_init`` is ``(batch, 1)``,
          or ``(1,)`` without batch. ``h_0``/``c_0`` hold the previous
          recurrent state (presumably in the same layout as the zeros built in
          training mode below; verify against the caller).

        Args:
            input: per-agent input features (see above).
            is_init: reset mask; where set, the recurrent state is zeroed.
            h_0: previous hidden state, or ``None`` in training.
            c_0: previous cell state, or ``None`` in training.

        Returns:
            ``(output, h_n, c_n)``. ``output`` keeps the agent dimension and has
            ``hidden_size`` features. In collection, the emulated sequence
            dimension (and the batch dimension, if it was missing) is removed
            again from all returned tensors.
        """
        # Input and output always have the multiagent dimension
        # Hidden states always have it apart from when it is centralized and share params
        # is_init never has it

        assert is_init is not None, "We need to pass is_init"
        training = h_0 is None

        missing_batch = False
        if (
            not training and len(input.shape) < 3
        ):  # In evaluation the batch might be missing
            missing_batch = True
            input = input.unsqueeze(0)
            h_0 = h_0.unsqueeze(0)
            c_0 = c_0.unsqueeze(0)
            is_init = is_init.unsqueeze(0)

        if (
            not training
        ):  # In collection we emulate the sequence dimension and we have the hidden state
            input = input.unsqueeze(1)

        # Check input
        # From here on both modes share the layout (batch, seq, n_agents, features).
        batch = input.shape[0]
        seq = input.shape[1]
        assert input.shape == (batch, seq, self.n_agents, self.input_size)

        if not training:  # Collection
            # Set hidden to 0 when is_init
            h_0 = torch.where(expand_as_right(is_init, h_0), 0, h_0)
            c_0 = torch.where(expand_as_right(is_init, c_0), 0, c_0)
            is_init = is_init.unsqueeze(
                1
            )  # If in collection emulate the sequence dimension

        assert is_init.shape == (batch, seq, 1)
        # Give is_init an agent dimension so it can be mapped alongside input.
        is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

        if training:
            # Zero initial state. A single shared centralised network has no
            # agent dimension in its state; every other configuration does.
            if self.centralised and self.share_params:
                shape = (
                    batch,
                    self.n_layers,
                    self.hidden_size,
                )
            else:
                shape = (
                    batch,
                    self.n_agents,
                    self.n_layers,
                    self.hidden_size,
                )
            # NOTE(review): dtype is fixed to float32 regardless of input dtype.
            h_0 = torch.zeros(
                shape,
                device=self.device,
                dtype=torch.float,
            )
            c_0 = h_0.clone()
        if self.centralised:
            # Flatten all agents' features into one input vector; the reset
            # mask is identical across agents, so keep a single copy.
            input = input.view(batch, seq, self.n_agents * self.input_size)
            is_init = is_init[..., 0, :]

        output, h_n, c_n = self.run_net(input, is_init, h_0, c_0)

        if self.centralised and self.share_params:
            # A single network produced one output; replicate it for every agent.
            output = output.unsqueeze(-2).expand(
                batch, seq, self.n_agents, self.hidden_size
            )

        # Undo the dimensions that were emulated above.
        if not training:
            output = output.squeeze(1)
        if missing_batch:
            output = output.squeeze(0)
            h_n = h_n.squeeze(0)
            c_n = c_n.squeeze(0)
        return output, h_n, c_n

    def run_net(self, input, is_init, h_0, c_0):
        """Evaluate the network(s) on inputs already reshaped by :meth:`forward`.

        The parameters in ``self.params`` are swapped into
        ``self._empty_lstm`` for the duration of the call. Without parameter
        sharing, :func:`torch.vmap` maps over dim 0 of ``self.params`` (one
        parameter set per agent). With sharing, one parameter set is used, and
        vmap (decentralised case only) maps over the agent dimension of the
        data.

        Returns:
            ``(output, h_n, c_n)`` from the network(s). Where vmap is used,
            ``out_dims`` puts the agent dimension back at ``-2`` for the output
            and at ``-3`` for the recurrent states.
        """
        if not self.share_params:
            if self.centralised:
                # Every agent network reads the same centralised input and
                # is_init (in_dims None); parameters (dim 0) and recurrent
                # states (agent dim -3) are mapped.
                output, h_n, c_n = self.vmap_func_module(
                    self._empty_lstm,
                    (0, None, None, -3, -3),
                    (-2, -3, -3),
                )(self.params, input, is_init, h_0, c_0)
            else:
                # Per-agent inputs and is_init are mapped on their agent dim
                # (-2), recurrent states on theirs (-3).
                output, h_n, c_n = self.vmap_func_module(
                    self._empty_lstm,
                    (0, -2, -2, -3, -3),
                    (-2, -3, -3),
                )(self.params, input, is_init, h_0, c_0)
        else:
            with self.params.to_module(self._empty_lstm):
                if self.centralised:
                    # One shared network on the concatenated input: no agent
                    # dimension to map over.
                    output, h_n, c_n = self._empty_lstm(input, is_init, h_0, c_0)
                else:
                    # Shared parameters: only the data's agent dim is mapped.
                    output, h_n, c_n = torch.vmap(
                        self._empty_lstm,
                        in_dims=(-2, -2, -3, -3),
                        out_dims=(-2, -3, -3),
                    )(input, is_init, h_0, c_0)

        return output, h_n, c_n

    def vmap_func_module(self, module, *args, **kwargs):
        """Return a vmapped function that runs ``module`` with given parameters.

        The returned callable takes ``(params, *inputs)``: it swaps ``params``
        into ``module`` and calls ``module(*inputs)``. ``*args``/``**kwargs``
        are forwarded to :func:`torch.vmap` (here: ``in_dims``, ``out_dims``).
        """

        def exec_module(params, *input):
            with params.to_module(module):
                return module(*input)

        return torch.vmap(exec_module, *args, **kwargs)

    def _make_params(self, agent_networks):
        """Collect the parameters of ``agent_networks`` into ``self.params``.

        With shared parameters only the first network is used. Otherwise the
        parameters of all networks are combined with
        ``TensorDict.from_modules``, and :meth:`run_net` vmaps over them along
        dim 0.
        """
        if self.share_params:
            self.params = TensorDict.from_module(agent_networks[0], as_module=True)
        else:
            self.params = TensorDict.from_modules(*agent_networks, as_module=True)


class Lstm(Model):
    r"""A multi-layer Long Short-Term Memory (LSTM) RNN like the one from `torch <https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html>`__ ,
    followed by an MLP head.

    The BenchMARL LSTM accepts multiple inputs of type array: Tensors of shape ``(*batch,F)``.
    Where `F` is the number of features. These arrays will be concatenated along the F dimension,
    which will be processed to features of `hidden_size` by the LSTM. The LSTM features are then
    mapped to the output features by the MLP head.

    Args:
        hidden_size (int): The number of features in the hidden state.
        n_layers (int): Number of recurrent layers. E.g., setting ``n_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias (bool): If ``False``, then the LSTM layers do not use bias.
            Default: ``True``
        dropout (float): If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        compile (bool): If ``True``, compiles underlying LSTM model. Default: ``False``

    Keyword arguments whose name starts with ``mlp_`` (e.g. ``mlp_num_cells``) are
    forwarded, with that prefix stripped, to the MLP head. The standard
    :class:`~benchmarl.models.common.Model` arguments are also passed as keyword
    arguments.
    """

    def __init__(
        self,
        hidden_size: int,
        n_layers: int,
        bias: bool,
        dropout: float,
        compile: bool,
        **kwargs,
    ):
        super().__init__(
            input_spec=kwargs.pop("input_spec"),
            output_spec=kwargs.pop("output_spec"),
            agent_group=kwargs.pop("agent_group"),
            input_has_agent_dim=kwargs.pop("input_has_agent_dim"),
            n_agents=kwargs.pop("n_agents"),
            centralised=kwargs.pop("centralised"),
            share_params=kwargs.pop("share_params"),
            device=kwargs.pop("device"),
            action_spec=kwargs.pop("action_spec"),
            model_index=kwargs.pop("model_index"),
            is_critic=kwargs.pop("is_critic"),
        )

        # Recurrent state entries (hidden "h" and cell "c"), nested under the
        # agent group and made unique per model index.
        self.hidden_state_name_h = (
            self.agent_group,
            f"_hidden_lstm_h_{self.model_index}",
        )
        self.hidden_state_name_c = (
            self.agent_group,
            f"_hidden_lstm_c_{self.model_index}",
        )
        self.rnn_keys = unravel_key_list(
            ["is_init", self.hidden_state_name_c, self.hidden_state_name_h]
        )
        # Besides the feature inputs, the model reads the reset flag and the
        # recurrent state.
        self.in_keys += self.rnn_keys

        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.bias = bias
        self.dropout = dropout
        self.compile = compile

        # All inputs are concatenated along their last (feature) dimension.
        self.input_features = sum(
            spec.shape[-1] for spec in self.input_spec.values(True, True)
        )
        self.output_features = self.output_leaf_spec.shape[-1]

        if self.input_has_agent_dim:
            self.lstm = MultiAgentLSTM(
                self.input_features,
                self.hidden_size,
                self.n_agents,
                self.device,
                bias=self.bias,
                n_layers=self.n_layers,
                centralised=self.centralised,
                share_params=self.share_params,
                dropout=self.dropout,
                compile=self.compile,
            )
        else:
            # Input without agent dimension: one LSTM per agent, or a single
            # shared one, all reading the same input.
            self.lstm = nn.ModuleList(
                [
                    get_net(
                        input_size=self.input_features,
                        hidden_size=self.hidden_size,
                        n_layers=self.n_layers,
                        bias=self.bias,
                        device=self.device,
                        dropout=self.dropout,
                        compile=self.compile,
                    )
                    for _ in range(self.n_agents if not self.share_params else 1)
                ]
            )

        # "mlp_<name>" keyword arguments configure the MLP head as "<name>".
        mlp_net_kwargs = {
            k[len("mlp_") :]: v for k, v in kwargs.items() if k.startswith("mlp_")
        }
        if self.output_has_agent_dim:
            self.mlp = MultiAgentMLP(
                n_agent_inputs=self.hidden_size,
                n_agent_outputs=self.output_features,
                n_agents=self.n_agents,
                centralised=self.centralised,
                share_params=self.share_params,
                device=self.device,
                **mlp_net_kwargs,
            )
        else:
            self.mlp = nn.ModuleList(
                [
                    MLP(
                        in_features=self.hidden_size,
                        out_features=self.output_features,
                        device=self.device,
                        **mlp_net_kwargs,
                    )
                    for _ in range(self.n_agents if not self.share_params else 1)
                ]
            )

    def _perform_checks(self):
        """Validate the input and output specs.

        Every input must be a 1-D feature vector (2-D, with the agent dimension
        first, when the input has the agent dimension), and all inputs must
        agree on every dimension but the last so they can be concatenated.

        Raises:
            ValueError: if there is no input, an input has an invalid rank, the
                inputs disagree on their leading dimensions, or the agent
                dimension of the input/output does not match ``n_agents``.
        """
        super()._perform_checks()

        input_shape = None
        for input_key, input_spec in self.input_spec.items(True, True):
            if (self.input_has_agent_dim and len(input_spec.shape) == 2) or (
                not self.input_has_agent_dim and len(input_spec.shape) == 1
            ):
                if input_shape is None:
                    input_shape = input_spec.shape[:-1]
                else:
                    if input_spec.shape[:-1] != input_shape:
                        raise ValueError(
                            f"LSTM inputs should all have the same shape up to the last dimension, got {self.input_spec}"
                        )
            else:
                raise ValueError(
                    f"LSTM input value {input_key} from {self.input_spec} has an invalid shape, maybe you need a CNN?"
                )
        if input_shape is None:
            # Nothing to concatenate: fail with a clear message instead of a
            # TypeError on ``input_shape[-1]`` (or a failing torch.cat later).
            raise ValueError(f"LSTM needs at least one input, got {self.input_spec}")
        if self.input_has_agent_dim:
            if input_shape[-1] != self.n_agents:
                raise ValueError(
                    "If the LSTM input has the agent dimension,"
                    f" the second to last spec dimension should be the number of agents, got {self.input_spec}"
                )
        if (
            self.output_has_agent_dim
            and self.output_leaf_spec.shape[-2] != self.n_agents
        ):
            raise ValueError(
                "If the LSTM output has the agent dimension,"
                " the second to last spec dimension should be the number of agents"
            )

    def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Run the LSTM and the MLP head, writing the result to ``self.out_key``.

        If the recurrent state is present in ``tensordict`` (collection), the
        updated state is written under ``"next"``. If it is absent (training
        on full sequences), the state starts from zeros and is not written back.
        """
        # Gather in_key
        lstm_input = torch.cat(
            [
                tensordict.get(in_key)
                for in_key in self.in_keys
                if in_key not in self.rnn_keys
            ],
            dim=-1,
        )
        h_0 = tensordict.get(self.hidden_state_name_h, None)
        c_0 = tensordict.get(self.hidden_state_name_c, None)
        is_init = tensordict.get("is_init")
        training = h_0 is None

        # Has multi-agent input dimension
        if self.input_has_agent_dim:
            output, h_n, c_n = self.lstm(lstm_input, is_init, h_0, c_0)
            if not self.output_has_agent_dim:
                output = output[..., 0, :]
        else:  # Is a global input, this is a critic
            # NOTE(review): this branch always starts from a zero state, only
            # accepts (batch, seq, features) inputs, and never produces
            # h_n / c_n, so it only supports full-sequence (training) calls.
            # Check input
            batch = lstm_input.shape[0]
            seq = lstm_input.shape[1]
            assert lstm_input.shape == (batch, seq, self.input_features)
            assert is_init.shape == (batch, seq, 1)

            h_0 = torch.zeros(
                (batch, self.n_layers, self.hidden_size),
                device=self.device,
                dtype=torch.float,
            )
            c_0 = h_0.clone()
            if self.share_params:
                output, _, _ = self.lstm[0](lstm_input, is_init, h_0, c_0)
            else:
                outputs = []
                for net in self.lstm:
                    output, _, _ = net(lstm_input, is_init, h_0, c_0)
                    outputs.append(output)
                # One output per agent network, stacked on a new agent dim.
                output = torch.stack(outputs, dim=-2)

        # Mlp
        if self.output_has_agent_dim:
            # Call the module (not .forward) so that module hooks run.
            output = self.mlp(output)
        else:
            if not self.share_params:
                output = torch.stack(
                    [net(output) for net in self.mlp],
                    dim=-2,
                )
            else:
                output = self.mlp[0](output)

        tensordict.set(self.out_key, output)
        if not training:
            tensordict.set(("next", *self.hidden_state_name_h), h_n)
            tensordict.set(("next", *self.hidden_state_name_c), c_n)
        return tensordict
@dataclass
class LstmConfig(ModelConfig):
    """Dataclass config for a :class:`~benchmarl.models.lstm.Lstm`."""

    hidden_size: int = MISSING
    n_layers: int = MISSING
    bias: bool = MISSING
    dropout: float = MISSING
    compile: bool = MISSING

    # Forwarded to the MLP head with the "mlp_" prefix stripped.
    mlp_num_cells: Sequence[int] = MISSING
    mlp_layer_class: Type[nn.Module] = MISSING
    mlp_activation_class: Type[nn.Module] = MISSING

    mlp_activation_kwargs: Optional[dict] = None
    # Defaults to None, so the annotation must admit it.
    mlp_norm_class: Optional[Type[nn.Module]] = None
    mlp_norm_kwargs: Optional[dict] = None

    @staticmethod
    def associated_class():
        """Return the model class built from this config (:class:`Lstm`)."""
        return Lstm

    @property
    def is_rnn(self) -> bool:
        """Always ``True``: the LSTM model is recurrent."""
        return True

    def get_model_state_spec(self, model_index: int = 0) -> Composite:
        """Return the spec of the recurrent state of one LSTM model.

        Args:
            model_index: index used to make the state keys unique per model.

        Returns:
            A :class:`Composite` with the cell (``_hidden_lstm_c_<i>``) and
            hidden (``_hidden_lstm_h_<i>``) states, each unbounded with shape
            ``(n_layers, hidden_size)``.
        """
        spec = Composite(
            {
                f"_hidden_lstm_c_{model_index}": Unbounded(
                    shape=(self.n_layers, self.hidden_size)
                ),
                f"_hidden_lstm_h_{model_index}": Unbounded(
                    shape=(self.n_layers, self.hidden_size)
                ),
            }
        )
        return spec