benchmarl.algorithms.Masac

class Masac(share_param_critic: bool, num_qvalue_nets: int, loss_function: str, delay_qvalue: bool, target_entropy: float | str, discrete_target_entropy_weight: float, alpha_init: float, min_alpha: float | None, max_alpha: float | None, fixed_alpha: bool, scale_mapping: str, use_tanh_normal: bool, coupled_discrete_values: bool, **kwargs)

Bases: Algorithm

Multi Agent Soft Actor Critic.

Parameters:
  • share_param_critic (bool) – Whether to share the parameters of the critics within agent groups.

  • num_qvalue_nets (int) – number of Q-value networks used.

  • loss_function (str) – loss function used to compute the value function loss.

  • delay_qvalue (bool) – Whether to separate the target Q value networks from the Q value networks used for data collection.

  • target_entropy (float or str, optional) – Target entropy for the stochastic policy. Default is “auto”, where target entropy is computed as -prod(n_actions).

  • discrete_target_entropy_weight (float) – weight for the target entropy term when actions are discrete.

  • alpha_init (float) – initial entropy multiplier.

  • min_alpha (float or None) – minimum value of alpha.

  • max_alpha (float or None) – maximum value of alpha.

  • fixed_alpha (bool) – if True, alpha will be fixed to its initial value. Otherwise, alpha will be optimized to match the ‘target_entropy’ value.

  • scale_mapping (str) – positive mapping function to be used with the std. Choices: “softplus”, “exp”, “relu”, “biased_softplus_1”.

  • use_tanh_normal (bool) – if True, use TanhNormal as the continuous action distribution, with support bound to the action domain. Otherwise, an IndependentNormal is used.

  • coupled_discrete_values (bool) – only relevant for discrete action spaces. If True, the critic predicts n_agents x n_actions action values given the global state (or the concatenation of the agents’ observations). If False, the critic predicts n_actions values given the global state and the actions of the other agents; this is done for all agents in parallel. True is more theoretically sound and should be preferred. However, if the number of outputs gets too large, you may want to try False.
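
As a rough illustration of the entropy-related parameters above, the sketch below computes the “auto” target entropy as -prod(n_actions) and shows one way alpha could be kept within min_alpha and max_alpha. It is not taken from the library: the helper names are hypothetical, n_actions is assumed to be the shape of a continuous action, and treating None as "no bound on that side" is also an assumption.

```python
import math
from typing import Optional, Sequence


def auto_target_entropy(action_shape: Sequence[int]) -> float:
    """Return -prod(action_shape), the "auto" rule quoted in the parameter list.

    Assumption: n_actions in the docs refers to the action shape.
    """
    return -float(math.prod(action_shape))


def clamp_alpha(
    alpha: float,
    min_alpha: Optional[float] = None,
    max_alpha: Optional[float] = None,
) -> float:
    """Clamp alpha to [min_alpha, max_alpha].

    Assumption: a bound set to None leaves that side unbounded.
    """
    if min_alpha is not None:
        alpha = max(alpha, min_alpha)
    if max_alpha is not None:
        alpha = min(alpha, max_alpha)
    return alpha


# A 2-dimensional continuous action gives a target entropy of -2.0.
target = auto_target_entropy((2,))
# An alpha of 5.0 with max_alpha=1.0 is clamped down to 1.0.
bounded = clamp_alpha(5.0, max_alpha=1.0)
```

With fixed_alpha set to True, no such optimization or clamping step would be needed, since alpha stays at alpha_init.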

_get_loss(group: str, policy_for_loss: TensorDictModule, continuous: bool) → Tuple[LossModule, bool]

Implement this function to return the LossModule for a specific group.

Parameters:
  • group (str) – agent group of the loss

  • policy_for_loss (TensorDictModule) – the policy to use in the loss

  • continuous (bool) – whether to return a loss for continuous or discrete actions

Returns: LossModule and a bool representing if the loss should have target parameters

_get_parameters(group: str, loss: LossModule) → Dict[str, Iterable]

Get the dictionary mapping loss names to the corresponding parameters to optimize for a given group loss.

Returns: a dictionary mapping loss names to a list of parameters
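
To make the return type concrete, here is a minimal sketch of the kind of dictionary described above. The loss names and the placeholder "parameters" (plain strings) are hypothetical stand-ins chosen for illustration, not values read from the library; in practice the lists would hold the tensors to optimize.

```python
from typing import Dict, Iterable


def count_parameter_groups(params_by_loss: Dict[str, Iterable]) -> Dict[str, int]:
    """Count how many parameter objects are registered under each loss name."""
    return {name: len(list(params)) for name, params in params_by_loss.items()}


# Hypothetical loss names mapped to placeholder parameter identifiers.
example_parameters = {
    "loss_actor": ["actor.weight", "actor.bias"],
    "loss_qvalue": ["q0.weight", "q0.bias", "q1.weight", "q1.bias"],
    "loss_alpha": ["log_alpha"],
}

counts = count_parameter_groups(example_parameters)
```

A mapping of this shape lets a training loop build one optimizer per loss name, each over only the parameters that loss should update.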

_get_policy_for_loss(group: str, model_config: ModelConfig, continuous: bool) → TensorDictModule

Get the non-explorative policy for a specific group.

Parameters:
  • group (str) – agent group of the policy

  • model_config (ModelConfig) – model config class

  • continuous (bool) – whether the policy should be continuous or discrete

Returns: TensorDictModule representing the policy

_get_policy_for_collection(policy_for_loss: TensorDictModule, group: str, continuous: bool) → TensorDictModule

Implement this function to add an explorative layer to the policy used in the loss.

Parameters:
  • policy_for_loss (TensorDictModule) – the group policy used in the loss

  • group (str) – agent group

  • continuous (bool) – whether the policy is continuous or discrete

Returns: TensorDictModule representing the explorative policy

process_batch(group: str, batch: TensorDictBase) → TensorDictBase

This function can be used to reshape data coming from collection before it is passed to the policy.

Parameters:
  • group (str) – agent group

  • batch (TensorDictBase) – the batch of data coming from the collector

Returns: the processed batch

get_discrete_value_module_coupled(group: str) → TensorDictModule
get_discrete_value_module_decoupled(group: str) → TensorDictModule
get_continuous_value_module(group: str) → TensorDictModule
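
The coupled and decoupled discrete value modules presumably correspond to the two settings of coupled_discrete_values. The sketch below only restates the output sizes given in that parameter's description as arithmetic, to show why the coupled critic's output can grow large; the function name is hypothetical and nothing here calls the library.

```python
def critic_output_width(n_agents: int, n_actions: int, coupled: bool) -> int:
    """Number of action values a discrete critic predicts in one output.

    coupled=True: n_agents x n_actions values from the global state.
    coupled=False: n_actions values per agent (computed for all agents in parallel).
    """
    if coupled:
        return n_agents * n_actions
    return n_actions


# With 4 agents and 5 discrete actions each:
coupled_width = critic_output_width(n_agents=4, n_actions=5, coupled=True)
decoupled_width = critic_output_width(n_agents=4, n_actions=5, coupled=False)
```

The coupled width scales with the number of agents, while the decoupled width does not, which is the trade-off the parameter description points to.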