benchmarl.models.Cnn
- class Cnn(*args, **kwargs)
Bases: Model
Convolutional Neural Network (CNN) model.
The BenchMARL CNN accepts multiple inputs of 2 types:
images: Tensors of shape (*batch,X,Y,C)
arrays: Tensors of shape (*batch,F)
The CNN model will check that all image inputs have the same shape (excluding the last dimension) and cat them along that dimension before processing them with torchrl.modules.ConvNet.
It will check that all array inputs have the same shape (excluding the last dimension) and cat them along that dimension.
It will then cat the arrays and processed images and feed them to the MLP together.
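The shape checks described above can be illustrated with plain shape tuples. This is only a sketch of the bookkeeping, not BenchMARL's implementation; the helper name `cat_last_dim_shape` and the example shapes are assumptions made for illustration.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def cat_last_dim_shape(shapes: Sequence[Shape]) -> Shape:
    """Shape obtained by concatenating tensors of these shapes along the last dimension."""
    if not shapes:
        raise ValueError("at least one shape is required")
    leading = shapes[0][:-1]
    for shape in shapes:
        # Every input must agree on everything except the last dimension.
        if shape[:-1] != leading:
            raise ValueError(f"shape {shape} does not match leading dims {leading}")
    return leading + (sum(shape[-1] for shape in shapes),)


# Two image inputs (*batch, X, Y, C) and two array inputs (*batch, F), with batch = (32,).
image_shape = cat_last_dim_shape([(32, 84, 84, 3), (32, 84, 84, 1)])
array_shape = cat_last_dim_shape([(32, 5), (32, 7)])
print(image_shape)  # (32, 84, 84, 4)
print(array_shape)  # (32, 12)
```

Inputs whose leading dimensions differ, for example an 84x84 image next to a 64x64 one, raise a `ValueError` in this sketch.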
- Parameters:
cnn_num_cells (int or Sequence of int) – number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells.
cnn_kernel_sizes (int or Sequence of int) – kernel size(s) of the conv network. If iterable, the length must match the depth, defined by the num_cells or depth arguments.
cnn_strides (int or Sequence of int) – stride(s) of the conv network. If iterable, the length must match the depth, defined by the num_cells or depth arguments.
cnn_paddings (int or Sequence of int) – padding size for every layer.
cnn_activation_class (Type[nn.Module] or callable) – activation class or constructor to be used.
cnn_activation_kwargs (dict or list of dicts, optional) – kwargs to be used with the activation class. A list of kwargs of length depth can also be passed, with one element per layer.
cnn_norm_class (Type or callable, optional) – normalization class or constructor, if any.
cnn_norm_kwargs (dict or list of dicts, optional) – kwargs to be used with the normalization layers. A list of kwargs of length depth can also be passed, with one element per layer.
mlp_num_cells (int or Sequence[int]) – number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells.
mlp_layer_class (Type[nn.Module]) – class to be used for the linear layers.
mlp_activation_class (Type[nn.Module]) – activation class to be used.
mlp_activation_kwargs (dict, optional) – kwargs to be used with the activation class.
mlp_norm_class (Type, optional) – normalization class, if any.
mlp_norm_kwargs (dict, optional) – kwargs to be used with the normalization layers.
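As a rough illustration of the int-or-sequence convention used by the cnn_* parameters, the sketch below broadcasts an integer to every layer, validates the length of a sequence against the depth, and tracks how one spatial dimension shrinks through the conv layers. The helpers `expand_per_layer` and `conv_output_size` and the example values are hypothetical and not part of BenchMARL or torchrl; the size formula is the standard one for a convolution with dilation 1.

```python
from typing import List, Sequence, Union

IntOrSeq = Union[int, Sequence[int]]


def expand_per_layer(value: IntOrSeq, depth: int, name: str) -> List[int]:
    """Broadcast an int to every layer, or validate a sequence's length."""
    if isinstance(value, int):
        return [value] * depth
    values = list(value)
    if len(values) != depth:
        raise ValueError(f"{name} has length {len(values)}, expected {depth}")
    return values


def conv_output_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of one spatial dimension after a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Hypothetical settings: three conv layers applied to an 84x84 image.
cnn_num_cells = [32, 64, 64]
depth = len(cnn_num_cells)
kernels = expand_per_layer([8, 4, 3], depth, "cnn_kernel_sizes")
strides = expand_per_layer([4, 2, 1], depth, "cnn_strides")
paddings = expand_per_layer(0, depth, "cnn_paddings")

size = 84
for kernel, stride, padding in zip(kernels, strides, paddings):
    size = conv_output_size(size, kernel, stride, padding)
print(size)  # 7
```

With these values the spatial size goes 84, 20, 9, 7, so the last conv layer would emit a 7x7 feature map with 64 channels before flattening.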
- _forward(tensordict: TensorDictBase) → TensorDictBase
Method to implement for the forward pass of the model. It should read self.in_keys, process them, and write self.out_key.
- Parameters:
tensordict (TensorDictBase) – the input td
Returns: the input td with the written self.out_key
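The read/process/write contract of _forward can be sketched without any dependencies by letting a plain dict stand in for the tensordict. `ToyModel`, its keys, and the list-concatenation "processing" are hypothetical and only show the flow of data; a real model would operate on tensors inside a TensorDictBase.

```python
from typing import Dict, List


class ToyModel:
    """Hypothetical stand-in for a model; a plain dict plays the role of the tensordict."""

    def __init__(self, in_keys: List[str], out_key: str) -> None:
        self.in_keys = in_keys
        self.out_key = out_key

    def _forward(self, tensordict: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Read every input listed in self.in_keys.
        inputs = [tensordict[key] for key in self.in_keys]
        # "Process" them: here, simply concatenate the feature lists.
        output = [x for features in inputs for x in features]
        # Write the result under self.out_key and return the same container.
        tensordict[self.out_key] = output
        return tensordict


model = ToyModel(in_keys=["observation", "velocity"], out_key="logits")
td = {"observation": [0.1, 0.2], "velocity": [1.0]}
result = model._forward(td)
print(result["logits"])  # [0.1, 0.2, 1.0]
```

The method returns the same container it received, with the new entry added alongside the inputs.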