benchmarl.algorithms.Algorithm

class Algorithm(experiment)

Bases: ABC

Abstract base class for an algorithm. Concrete algorithms must subclass it and implement all of its abstract methods.

Parameters:

experiment (Experiment) – the experiment this algorithm belongs to
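Because Algorithm derives from ABC, Python will not instantiate a subclass until every abstract method has been overridden. The sketch below shows that rule on a hypothetical ToyAlgorithm that declares only two of the hooks. It is a pure-Python stand-in, not BenchMARL's class.

```python
from abc import ABC, abstractmethod


class ToyAlgorithm(ABC):
    """Hypothetical stand-in for Algorithm with only two abstract hooks."""

    def __init__(self, experiment):
        self.experiment = experiment

    @abstractmethod
    def _get_loss(self, group, policy_for_loss, continuous):
        ...

    @abstractmethod
    def _get_parameters(self, group, loss):
        ...


class Incomplete(ToyAlgorithm):
    # Overrides only one of the two abstract methods.
    def _get_loss(self, group, policy_for_loss, continuous):
        return "loss", False


class Complete(Incomplete):
    # Overrides the remaining abstract method, so it can be instantiated.
    def _get_parameters(self, group, loss):
        return {"loss_objective": []}


def can_instantiate(cls) -> bool:
    try:
        cls(experiment=None)
        return True
    except TypeError:  # raised by ABC when abstract methods remain
        return False


print(can_instantiate(Incomplete), can_instantiate(Complete))  # False True
```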

_check_specs()

get_loss_and_updater(group: str) → Tuple[LossModule, TargetNetUpdater]

Get the LossModule and TargetNetUpdater for a specific group. This function calls the abstract _get_loss(), which subclasses must implement. The output is cached on the first call, and the cached values are returned on subsequent calls.

Parameters:

group (str) – agent group of the loss and updater

Returns: LossModule and TargetNetUpdater for the group
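The caching described above amounts to memoizing the result per group. A minimal sketch, assuming a plain dictionary keyed by group name. CachingSketch and its simplified _get_loss() (which ignores policy_for_loss and continuous, and returns placeholder strings instead of a LossModule and TargetNetUpdater) are hypothetical.

```python
from typing import Dict, Tuple


class CachingSketch:
    """Hypothetical illustration of per-group caching; not BenchMARL's code."""

    def __init__(self) -> None:
        self._losses: Dict[str, Tuple[str, str]] = {}  # group -> (loss, updater)
        self.build_calls = 0  # counts how often the builder actually runs

    def _get_loss(self, group: str) -> Tuple[str, str]:
        # Stand-in for the abstract builder; imagine this is expensive.
        self.build_calls += 1
        return f"loss[{group}]", f"updater[{group}]"

    def get_loss_and_updater(self, group: str) -> Tuple[str, str]:
        # Build on the first call for this group, then reuse the stored result.
        if group not in self._losses:
            self._losses[group] = self._get_loss(group)
        return self._losses[group]


algo = CachingSketch()
first = algo.get_loss_and_updater("agents")
second = algo.get_loss_and_updater("agents")
print(first is second, algo.build_calls)  # True 1
```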

get_replay_buffer(group: str, transforms: List[Transform] | None = None) → ReplayBuffer

Get the ReplayBuffer for a specific group. This function checks self.on_policy and creates the buffer accordingly.

Parameters:
  • group (str) – agent group of the replay buffer

  • transforms (optional, list of Transform) – transforms to apply in the replay buffer's .sample() call

Returns: ReplayBuffer for the group
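The docstring does not say how the buffer differs between the two regimes. As an assumption for illustration only, the sketch below contrasts sampling without replacement (each stored transition used once per pass, a common choice for on-policy data) with uniform sampling with replacement (common for off-policy data). make_sampler is hypothetical and uses only the standard library rather than a TorchRL ReplayBuffer.

```python
import random
from typing import Callable, List


def make_sampler(on_policy: bool, size: int, batch_size: int,
                 seed: int = 0) -> Callable[[], List[List[int]]]:
    """Hypothetical sampler factory returning one pass of index minibatches."""
    rng = random.Random(seed)

    def sample_epoch() -> List[List[int]]:
        if on_policy:
            # Without replacement: shuffle, then cut into minibatches, so
            # every stored index appears exactly once per pass.
            order = list(range(size))
            rng.shuffle(order)
            return [order[i:i + batch_size] for i in range(0, size, batch_size)]
        # With replacement: indices are drawn uniformly and may repeat.
        return [[rng.randrange(size) for _ in range(batch_size)]
                for _ in range(size // batch_size)]

    return sample_epoch


on_batches = make_sampler(on_policy=True, size=8, batch_size=4)()
print(sorted(i for batch in on_batches for i in batch))  # [0, 1, 2, 3, 4, 5, 6, 7]
```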

get_policy_for_loss(group: str) → TensorDictModule

Get the non-explorative policy for a specific group's loss. This function calls the abstract _get_policy_for_loss(), which subclasses must implement. The output is cached on the first call, and the cached values are returned on subsequent calls.

Parameters:

group (str) – agent group of the policy

Returns: TensorDictModule representing the policy

get_policy_for_collection() → TensorDictSequential

Get the explorative policy for all groups together. This function calls the abstract _get_policy_for_collection(), which subclasses must implement. The output is cached on the first call, and the cached values are returned on subsequent calls.

Returns: TensorDictSequential representing all explorative policies
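Combining the groups can be pictured as running each group's policy in turn on one shared set of entries, with each policy reading its own inputs and writing its own outputs. The sketch below uses plain dictionaries and callables in place of TensorDictModule and TensorDictSequential. The group names and the "red/observation" style keys are hypothetical.

```python
from typing import Callable, Dict, List

Entries = Dict[str, float]
Policy = Callable[[Entries], Entries]


def make_group_policy(group: str) -> Policy:
    """Hypothetical per-group policy: reads the group's observation key and
    writes the group's action key into the shared dict."""

    def policy(td: Entries) -> Entries:
        td[f"{group}/action"] = 2.0 * td[f"{group}/observation"]
        return td

    return policy


def sequential(policies: List[Policy]) -> Policy:
    """Apply each policy in order to the same dict, loosely mirroring how a
    sequential container chains modules."""

    def run(td: Entries) -> Entries:
        for policy in policies:
            td = policy(td)
        return td

    return run


collection_policy = sequential([make_group_policy("red"), make_group_policy("blue")])
out = collection_policy({"red/observation": 1.0, "blue/observation": 3.0})
print(out["red/action"], out["blue/action"])  # 2.0 6.0
```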

get_parameters(group: str) → Dict[str, Iterable]

Get the dictionary mapping loss names to the corresponding parameters to optimize for a given group. This function calls the abstract _get_parameters(), which subclasses must implement.

Returns: a dictionary mapping each loss name to a list of parameters
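A dictionary keyed by loss name lets a trainer keep one optimizer per loss, each touching only its own parameters. The pure-Python sketch below assumes that usage. Param, ToySGD and the loss names loss_actor and loss_critic are hypothetical stand-ins for torch parameters, a torch optimizer, and whatever names an algorithm actually returns.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List


@dataclass
class Param:
    value: float
    grad: float = 0.0


class ToySGD:
    """Minimal gradient-descent step over a fixed parameter list."""

    def __init__(self, params: Iterable[Param], lr: float):
        self.params: List[Param] = list(params)
        self.lr = lr

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


actor_w = Param(1.0, grad=0.5)
critic_w = Param(2.0, grad=1.0)

# Hypothetical output of get_parameters(group): loss name -> parameters.
params_by_loss: Dict[str, Iterable[Param]] = {
    "loss_actor": [actor_w],
    "loss_critic": [critic_w],
}

# One optimizer per loss name; stepping one leaves the others untouched.
optimizers = {name: ToySGD(ps, lr=0.5) for name, ps in params_by_loss.items()}
optimizers["loss_actor"].step()
print(actor_w.value, critic_w.value)  # 0.75 2.0
```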

process_env_fun(env_fun: Callable[[], EnvBase]) → Callable[[], EnvBase]

This function can be used to wrap env_fun, for instance to customize how environments are created.

Parameters:

env_fun (callable) – a function that takes no arguments and creates an environment

Returns: a function that takes no arguments and creates an environment
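A wrapper has to preserve the contract stated above: it takes a zero-argument factory and returns another zero-argument factory. The sketch below assumes a hypothetical ToyEnv in place of EnvBase and a made-up post-processing step; toy_process_env_fun is not BenchMARL's default implementation.

```python
from typing import Callable, List


class ToyEnv:
    """Hypothetical stand-in for an EnvBase instance."""

    def __init__(self) -> None:
        self.wrappers: List[str] = []


def toy_process_env_fun(env_fun: Callable[[], ToyEnv]) -> Callable[[], ToyEnv]:
    """Return a new zero-argument factory that builds the environment with
    env_fun and then applies a (hypothetical) post-processing step."""

    def wrapped() -> ToyEnv:
        env = env_fun()
        env.wrappers.append("algorithm-specific transform")
        return env

    return wrapped


make_env = toy_process_env_fun(ToyEnv)
env_a, env_b = make_env(), make_env()
print(env_a is env_b, env_a.wrappers)  # False ['algorithm-specific transform']
```

Returning a new closure keeps creation lazy: nothing is built until the returned function is called, and each call produces a fresh environment.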

abstract _get_loss(group: str, policy_for_loss: TensorDictModule, continuous: bool) → Tuple[LossModule, bool]

Implement this function to return the LossModule for a specific group.

Parameters:
  • group (str) – agent group of the loss

  • policy_for_loss (TensorDictModule) – the policy to use in the loss

  • continuous (bool) – whether to return a loss for continuous or discrete actions

Returns: LossModule and a bool indicating whether the loss should have target parameters
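The second return value tells the caller whether target parameters are needed. One plausible way a caller such as get_loss_and_updater() could act on it is sketched below. ToyLoss, the needs_target argument and the string placeholder for the updater are hypothetical, and the real method's behavior when no target is needed is not documented here.

```python
from typing import Optional, Tuple


class ToyLoss:
    """Hypothetical stand-in for a LossModule."""

    def __init__(self, name: str):
        self.name = name


def toy_get_loss(group: str, continuous: bool,
                 needs_target: bool) -> Tuple[ToyLoss, bool]:
    # The algorithm decides whether its loss keeps target parameters;
    # in this sketch that choice is simply passed in as needs_target.
    kind = "continuous" if continuous else "discrete"
    return ToyLoss(f"{group}:{kind}"), needs_target


def toy_loss_and_updater(group: str, continuous: bool,
                         needs_target: bool) -> Tuple[ToyLoss, Optional[str]]:
    loss, use_target = toy_get_loss(group, continuous, needs_target)
    # Build a placeholder target-network updater only when the loss asks for one.
    updater = f"target-updater({loss.name})" if use_target else None
    return loss, updater


loss, updater = toy_loss_and_updater("agents", continuous=False, needs_target=True)
print(loss.name, updater)  # agents:discrete target-updater(agents:discrete)
```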

abstract _get_parameters(group: str, loss: LossModule) → Dict[str, Iterable]

Implement this function to return the dictionary mapping loss names to the corresponding parameters to optimize for a given group loss.

Returns: a dictionary mapping each loss name to a list of parameters

abstract _get_policy_for_loss(group: str, model_config: ModelConfig, continuous: bool) → TensorDictModule

Get the non-explorative policy for a specific group.

Parameters:
  • group (str) – agent group of the policy

  • model_config (ModelConfig) – the model configuration used to build the policy

  • continuous (bool) – whether the policy should be continuous or discrete

Returns: TensorDictModule representing the policy
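Non-explorative means the policy is deterministic for a given input. A minimal sketch, assuming a linear score per action and omitting model_config entirely: the discrete branch acts greedily, and the continuous branch squashes a score into [-1, 1]. toy_policy_for_loss and its weights argument are hypothetical and do not reflect how BenchMARL builds models from a ModelConfig.

```python
import math
from typing import Callable, List, Union

Action = Union[int, float]


def toy_policy_for_loss(continuous: bool,
                        weights: List[float]) -> Callable[[List[float]], Action]:
    """Hypothetical deterministic policy: no randomness, so the same
    observation always maps to the same action."""

    def scores(obs: List[float]) -> List[float]:
        return [w * o for w, o in zip(weights, obs)]

    def squashed(obs: List[float]) -> float:
        # Continuous action: squash the summed score into [-1, 1].
        return math.tanh(sum(scores(obs)))

    def greedy(obs: List[float]) -> int:
        # Discrete action: index of the highest-scoring action.
        s = scores(obs)
        return max(range(len(s)), key=s.__getitem__)

    return squashed if continuous else greedy


discrete = toy_policy_for_loss(False, [1.0, -1.0, 0.5])
continuous = toy_policy_for_loss(True, [1.0, -1.0, 0.5])
print(discrete([1.0, 1.0, 1.0]))  # 0
```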

abstract _get_policy_for_collection(policy_for_loss: TensorDictModule, group: str, continuous: bool) → TensorDictModule

Implement this function to add an explorative layer to the policy used in the loss.

Parameters:
  • policy_for_loss (TensorDictModule) – the group policy used in the loss

  • group (str) – agent group

  • continuous (bool) – whether the policy is continuous or discrete

Returns: TensorDictModule representing the explorative policy
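An explorative layer can leave the loss policy untouched and only perturb its output. The sketch below uses two standard exploration schemes, epsilon-greedy for discrete actions and clipped Gaussian noise for continuous ones. These are illustrative choices, not necessarily what any BenchMARL algorithm uses, and the n_actions, eps, sigma and seed parameters are hypothetical additions to the documented signature.

```python
import random
from typing import Callable


def toy_policy_for_collection(policy_for_loss: Callable, group: str,
                              continuous: bool, n_actions: int = 2,
                              eps: float = 0.1, sigma: float = 0.1,
                              seed: int = 0) -> Callable:
    """Hypothetical wrapper adding exploration on top of policy_for_loss.
    The group argument is unused in this sketch."""
    rng = random.Random(seed)

    def explore(obs):
        action = policy_for_loss(obs)  # reuse the loss policy unchanged
        if continuous:
            # Gaussian noise on the deterministic action, clipped to [-1, 1].
            return max(-1.0, min(1.0, action + rng.gauss(0.0, sigma)))
        # Epsilon-greedy: with probability eps take a uniformly random action.
        if rng.random() < eps:
            return rng.randrange(n_actions)
        return action

    return explore


greedy_only = toy_policy_for_collection(lambda obs: 2, "agents", False,
                                        n_actions=4, eps=0.0)
print(greedy_only(None))  # 2
```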

abstract process_batch(group: str, batch: TensorDictBase) → TensorDictBase

This function can be used to reshape data coming from collection before it is passed to the policy.

Parameters:
  • group (str) – agent group

  • batch (TensorDictBase) – the batch of data coming from the collector

Returns: the processed batch
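One common reshape is merging leading batch dimensions. The sketch below assumes, purely for illustration, that every entry arrives as nested lists shaped (n_envs, n_steps) and flattens them. toy_process_batch and the entry names are hypothetical, and plain dictionaries stand in for TensorDictBase.

```python
from typing import Dict, List


def toy_process_batch(group: str,
                      batch: Dict[str, List[List[float]]]) -> Dict[str, List[float]]:
    """Hypothetical reshape: merge the (n_envs, n_steps) leading dimensions
    of every entry into one flat dimension of n_envs * n_steps samples.
    The group argument is unused in this sketch."""
    return {key: [x for per_env in rows for x in per_env]
            for key, rows in batch.items()}


batch = {"reward": [[1.0, 2.0], [3.0, 4.0]], "done": [[0.0, 1.0], [0.0, 0.0]]}
print(toy_process_batch("agents", batch)["reward"])  # [1.0, 2.0, 3.0, 4.0]
```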

process_loss_vals(group: str, loss_vals: TensorDictBase) → TensorDictBase

Use this function to modify the loss_vals tensordict, whose entries map each loss_name to its loss_value. For example, you can sum two entries into a new entry to optimize them together.

Parameters:
  • group (str) – agent group

  • loss_vals (TensorDictBase) – the tensordict returned by the loss forward method

Returns: the processed loss_vals
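The summing example above can be sketched with a plain dictionary standing in for the tensordict. The key names loss_objective, loss_entropy, loss_critic and loss_actor_total are hypothetical, as is toy_process_loss_vals.

```python
from typing import Dict


def toy_process_loss_vals(group: str,
                          loss_vals: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical example: sum two entries into a new entry so they are
    optimized together. The group argument is unused in this sketch."""
    out = dict(loss_vals)  # copy so the caller's dict is left unchanged
    out["loss_actor_total"] = out.pop("loss_objective") + out.pop("loss_entropy")
    return out


vals = {"loss_objective": 1.0, "loss_entropy": -0.25, "loss_critic": 0.5}
print(toy_process_loss_vals("agents", vals))
# {'loss_critic': 0.5, 'loss_actor_total': 0.75}
```

Since get_parameters() is keyed by loss name, an algorithm that merges entries this way would presumably also need _get_parameters() to return parameters under the new key.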
