benchmarl.models.Model
- class Model(*args, **kwargs)
Bases: TensorDictModuleBase, ABC
Abstract class representing a model.
Models in BenchMARL are instantiated per agent group. This means that each model processes the inputs for a whole group of agents. Models are associated with input and output specs that define their domains.
- Parameters:
input_spec (Composite) – the input spec of the model
output_spec (Composite) – the output spec of the model
agent_group (str) – the name of the agent group the model is for
n_agents (int) – the number of agents this module is for
device (str) – the model’s device
input_has_agent_dim (bool) – Whether the model’s input has a multi-agent dimension. For example, policy inputs always have this set to True, whereas critics that use a global state have it set to False, because the state is shared by all agents.
centralised (bool) – Whether the model has full observability. This is always True when self.input_has_agent_dim == False. When the input does have the agent dimension, this parameter distinguishes between a decentralised model (where each agent’s data is processed separately) and a centralised model (where the model pools all data together).
share_params (bool) – Whether the model should have a single set of parameters or a different set of parameters for each agent. This is independent of the other options: for example, it is possible to have different parameters for centralised critics with global input.
action_spec (Composite) – The action spec of the environment
model_index (int) – the index of the model in a sequence
is_critic (bool) – Whether the model is a critic
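As a rough illustration of how input_has_agent_dim and centralised combine, here is a minimal plain-Python sketch. The helpers describe_input_mode and expected_input_shape are hypothetical (not part of BenchMARL). The layout (*batch, n_agents, features) is assumed from the usual multi-agent convention rather than stated above.

```python
from typing import Tuple


def describe_input_mode(input_has_agent_dim: bool, centralised: bool) -> str:
    """Hypothetical helper: summarise how a group model treats its input."""
    if not input_has_agent_dim:
        # Documented rule: without an agent dimension the model is always centralised.
        if not centralised:
            raise ValueError("centralised must be True when input_has_agent_dim is False")
        return "global input shared by all agents"
    if centralised:
        return "centralised: pools the data of all agents together"
    return "decentralised: processes each agent's data separately"


def expected_input_shape(
    batch_shape: Tuple[int, ...],
    n_agents: int,
    n_features: int,
    input_has_agent_dim: bool,
) -> Tuple[int, ...]:
    # Assumed layout: the agent dimension sits just before the feature dimension.
    if input_has_agent_dim:
        return (*batch_shape, n_agents, n_features)
    return (*batch_shape, n_features)


# A policy input vs. a critic fed a global state (sizes are illustrative only).
print(describe_input_mode(True, False), expected_input_shape((32,), 4, 18, True))
print(describe_input_mode(False, True), expected_input_shape((32,), 4, 50, False))
```

Raising on input_has_agent_dim=False with centralised=False mirrors the rule above that such a model is always centralised; whether the library itself raises in that case is not stated here.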
- property output_has_agent_dim: bool
This is a dynamically computed attribute that indicates whether the output will have the agent dimension. It is False when share_params == True and centralised == True, and True in all other cases. When output_has_agent_dim is True, your model’s output should contain the multi-agent dimension; otherwise, that dimension should be absent.
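The rule above fits in a small pure-Python function. In the sketch below, output_has_agent_dim is a hypothetical stand-alone re-implementation of the documented rule, not the library’s property. expected_output_shape is likewise hypothetical and assumes the (*batch, n_agents, features) layout.

```python
from typing import Tuple


def output_has_agent_dim(share_params: bool, centralised: bool) -> bool:
    # Documented rule: only a centralised model with shared parameters
    # produces a single output for the whole group.
    return not (share_params and centralised)


def expected_output_shape(
    batch_shape: Tuple[int, ...],
    n_agents: int,
    out_features: int,
    share_params: bool,
    centralised: bool,
) -> Tuple[int, ...]:
    # Assumed layout: the agent dimension sits just before the feature dimension.
    if output_has_agent_dim(share_params, centralised):
        return (*batch_shape, n_agents, out_features)
    return (*batch_shape, out_features)


# Enumerate all four combinations for a batch of 32, 4 agents, 1 output feature.
for share in (True, False):
    for central in (True, False):
        print(share, central, expected_output_shape((32,), 4, 1, share, central))
```

Only the share_params=True, centralised=True combination yields a per-group output, of shape (32, 1) here. The other three combinations keep the agent dimension, giving (32, 4, 1).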
- forward(tensordict: TensorDictBase) → TensorDictBase
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Share parameters with another identical model.
This function modifies in-place the parameters of other_model to reference the parameters of self.
- Parameters:
other_model (Model) – the model that will share the parameters of self.
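The reference semantics described above can be illustrated with a plain-Python sketch in which parameters are mutable lists instead of torch tensors. TinyModel and share_with are hypothetical names chosen for this illustration. The real method works on the two models’ torch parameters, and its exact mechanics are not reproduced here.

```python
from typing import Dict, List


class TinyModel:
    """Hypothetical stand-in for a Model: parameters are mutable lists."""

    def __init__(self, weights: List[float]):
        self.params: Dict[str, List[float]] = {"weight": list(weights), "bias": [0.0]}

    def share_with(self, other: "TinyModel") -> None:
        # Only "identical" models may share: same parameter names and sizes.
        if self.params.keys() != other.params.keys() or any(
            len(self.params[k]) != len(other.params[k]) for k in self.params
        ):
            raise ValueError("models are not identical")
        # Rebind each of other's parameters to the very same object held by
        # self, so later in-place updates through either model affect both.
        for name, value in self.params.items():
            other.params[name] = value


a = TinyModel([1.0, 2.0])
b = TinyModel([5.0, 6.0])
a.share_with(b)
a.params["weight"][0] = 10.0
print(b.params["weight"])  # [10.0, 2.0]
```

After sharing, b’s previous values are discarded and both models read and update the same parameter objects.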
- abstract _forward(tensordict: TensorDictBase) → TensorDictBase
Method to implement for the forward pass of the model. It should read self.in_keys, process them and write self.out_key.
- Parameters:
tensordict (TensorDictBase) – the input td
Returns: the input td, with self.out_key written to it
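A minimal sketch of the _forward contract: read every key in self.in_keys, process the values, write the result under self.out_key and return the same td. A plain dict stands in for TensorDictBase and Python lists stand in for tensors. ConcatModel and its concatenate-and-double step are hypothetical, chosen only to keep the example dependency-free. A real subclass would inherit from Model and operate on tensors.

```python
from typing import Dict, List


class ConcatModel:
    """Hypothetical sketch of the _forward contract (no BenchMARL imports)."""

    def __init__(self, in_keys: List[str], out_key: str):
        self.in_keys = in_keys
        self.out_key = out_key

    def _forward(self, tensordict: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Read every input key...
        inputs = [tensordict[key] for key in self.in_keys]
        # ...process them (here: concatenate the features and double them)...
        features = [2 * x for chunk in inputs for x in chunk]
        # ...then write the result under the output key and return the same td.
        tensordict[self.out_key] = features
        return tensordict


td = {"obs": [1.0, 2.0], "goal": [3.0]}
out = ConcatModel(["obs", "goal"], "logits")._forward(td)
print(out["logits"])  # [2.0, 4.0, 6.0]
print(out is td)      # True
```

Returning the input object, rather than a new one, matches the description above of returning the input td with self.out_key written to it.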