benchmarl.algorithms.Algorithm
- class Algorithm(experiment)
Bases: ABC
Abstract class for an algorithm. Concrete algorithms should subclass it and implement all of its abstract methods.
- Parameters:
experiment (Experiment) – the experiment class
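Because Algorithm derives from ABC, a subclass that leaves any abstract method unimplemented cannot be instantiated. The stdlib-only sketch below illustrates that behaviour. ToyAlgorithm, Incomplete and Complete are hypothetical names, not part of BenchMARL.

```python
from abc import ABC, abstractmethod


class ToyAlgorithm(ABC):
    """Hypothetical stand-in for Algorithm: public methods delegate to abstract hooks."""

    @abstractmethod
    def _get_loss(self, group: str) -> str:
        ...


class Incomplete(ToyAlgorithm):
    pass  # forgets to implement _get_loss


class Complete(ToyAlgorithm):
    def _get_loss(self, group: str) -> str:
        return f"loss for {group}"


try:
    Incomplete()  # raises TypeError: abstract method _get_loss is missing
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)

print(Complete()._get_loss("agents"))
```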
- get_loss_and_updater(group: str) → Tuple[LossModule, TargetNetUpdater]
Get the LossModule and TargetNetUpdater for a specific group. This function calls the abstract _get_loss(), which needs to be implemented. The output is cached on the first call and returned on later calls.
- Parameters:
group (str) – agent group of the loss and updater
Returns: LossModule and TargetNetUpdater for the group
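The cache-on-first-call pattern can be sketched in plain Python as follows. This is a simplified illustration, not the library code: the loss and updater are strings, and building an updater only when _get_loss reports target parameters is an assumption about how the returned bool is used. ToyAlgorithm and CountingAlgo are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple


class ToyAlgorithm(ABC):
    """Hypothetical sketch of the cache-on-first-call pattern."""

    def __init__(self) -> None:
        # group name -> (loss, target-net updater or None)
        self._losses_and_updaters: Dict[str, Tuple[object, Optional[str]]] = {}

    def get_loss_and_updater(self, group: str) -> Tuple[object, Optional[str]]:
        if group not in self._losses_and_updaters:
            loss, use_target = self._get_loss(group)
            # Assumption: an updater is only needed when the loss has target parameters.
            updater = f"target-updater({group})" if use_target else None
            self._losses_and_updaters[group] = (loss, updater)
        return self._losses_and_updaters[group]

    @abstractmethod
    def _get_loss(self, group: str) -> Tuple[object, bool]:
        ...


class CountingAlgo(ToyAlgorithm):
    def __init__(self) -> None:
        super().__init__()
        self.calls = 0  # counts how often the abstract hook actually runs

    def _get_loss(self, group: str) -> Tuple[object, bool]:
        self.calls += 1
        return f"loss({group})", True


algo = CountingAlgo()
first = algo.get_loss_and_updater("agents")
second = algo.get_loss_and_updater("agents")
print(first is second, algo.calls)  # True 1
```

Caching matters because the loss owns trainable modules. Rebuilding it on every call would silently create fresh, untrained parameters.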
- get_replay_buffer(group: str, transforms: List[Transform] | None = None) → ReplayBuffer
Get the ReplayBuffer for a specific group. This function checks self.on_policy and creates the buffer accordingly.
- Parameters:
group (str) – agent group of the replay buffer
transforms (List[Transform], optional) – transforms to apply to the replay buffer
Returns: ReplayBuffer for the group
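The sketch below shows one way an on-policy flag could drive buffer construction. It uses collections.deque as a stand-in for a ReplayBuffer. The sizing rule (on-policy keeps only the latest batch, off-policy keeps a larger memory) and the ToyConfig fields are assumptions for illustration, not BenchMARL's actual configuration.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque


@dataclass
class ToyConfig:
    on_policy: bool
    batch_size: int = 4      # hypothetical: frames collected per iteration
    memory_size: int = 100   # hypothetical: off-policy buffer capacity


def make_buffer(cfg: ToyConfig) -> Deque[int]:
    # On-policy: keep only the most recent batch, since older data is stale.
    # Off-policy: keep a larger memory of past transitions for reuse.
    capacity = cfg.batch_size if cfg.on_policy else cfg.memory_size
    return deque(maxlen=capacity)


on_buf = make_buffer(ToyConfig(on_policy=True))
off_buf = make_buffer(ToyConfig(on_policy=False))
for step in range(10):
    on_buf.append(step)
    off_buf.append(step)
print(list(on_buf))  # [6, 7, 8, 9]
print(len(off_buf))  # 10
```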
- get_policy_for_loss(group: str) → TensorDictModule
Get the non-explorative policy for a specific group loss. This function calls the abstract _get_policy_for_loss(), which needs to be implemented. The output is cached on the first call and returned on later calls.
- Parameters:
group (str) – agent group of the policy
Returns: TensorDictModule representing the policy
- get_policy_for_collection() → TensorDictSequential
Get the explorative policy for all groups together. This function calls the abstract _get_policy_for_collection(), which needs to be implemented. The output is cached on the first call and returned on later calls.
Returns: TensorDictSequential representing all explorative policies
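The idea of one collection policy covering every group can be illustrated without TorchRL. In the sketch below, each per-group policy reads and writes keys in a shared dict, and a sequential wrapper runs them in order, loosely mirroring how a sequential container chains modules over one tensordict. The key naming scheme and group names are hypothetical.

```python
from typing import Callable, Dict, List

Data = Dict[str, float]
Policy = Callable[[Data], Data]


def make_group_policy(group: str, scale: float) -> Policy:
    # Hypothetical per-group policy: reads "<group>_obs", writes "<group>_action".
    def policy(data: Data) -> Data:
        data[f"{group}_action"] = scale * data[f"{group}_obs"]
        return data
    return policy


def sequential(policies: List[Policy]) -> Policy:
    # Apply each group's policy in turn to the same shared data dict.
    def run(data: Data) -> Data:
        for p in policies:
            data = p(data)
        return data
    return run


collection_policy = sequential([
    make_group_policy("agents", 2.0),
    make_group_policy("adversaries", -1.0),
])
out = collection_policy({"agents_obs": 1.5, "adversaries_obs": 3.0})
print(out["agents_action"], out["adversaries_action"])  # 3.0 -3.0
```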
- get_parameters(group: str) → Dict[str, Iterable]
Get the dictionary mapping loss names to the corresponding parameters to optimize for a given group. This function calls the abstract _get_parameters(), which needs to be implemented.
- Parameters:
group (str) – agent group of the parameters
Returns: a dictionary mapping loss names to a list of parameters
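One plausible way to consume such a mapping is one optimizer per loss name, so that each loss updates only its own parameters. The sketch below uses a toy Param class and plain SGD instead of torch tensors and optimizers. The loss names, gradients and learning rates are hypothetical values chosen for illustration.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Param:
    value: float
    grad: float = 0.0


actor_w = Param(1.0, grad=0.5)
critic_w = Param(2.0, grad=-1.0)

# Hypothetical mapping in the shape get_parameters returns:
# loss name -> parameters that loss should update.
params_by_loss: Dict[str, List[Param]] = {
    "loss_objective": [actor_w],
    "loss_critic": [critic_w],
}

# One plain-SGD "optimizer" per loss name, each touching only its own parameters.
learning_rates = {"loss_objective": 0.1, "loss_critic": 0.01}
for loss_name, params in params_by_loss.items():
    lr = learning_rates[loss_name]
    for p in params:
        p.value -= lr * p.grad

print(actor_w.value, critic_w.value)
```

Keying parameters by loss name allows, for example, different learning rates for the actor and the critic.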
- process_env_fun(env_fun: Callable[[], EnvBase]) → Callable[[], EnvBase]
This function can be used to wrap env_fun, for example to modify the environment it creates.
- Parameters:
env_fun (callable) – a function that takes no arguments and creates an environment
Returns: a function that takes no arguments and creates an environment
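The wrapping pattern returns a new zero-argument factory rather than an environment. Environment creation therefore stays lazy, and every call still yields a fresh instance. In the sketch below, ToyEnv and RewardScaling are hypothetical stand-ins for an EnvBase and an environment wrapper.

```python
from typing import Callable


class ToyEnv:
    """Hypothetical stand-in for an environment."""

    def step(self, action: float) -> float:
        return action * 10.0  # toy reward


class RewardScaling:
    """Hypothetical wrapper that rescales the wrapped env's reward."""

    def __init__(self, env: ToyEnv, scale: float) -> None:
        self.env = env
        self.scale = scale

    def step(self, action: float) -> float:
        return self.scale * self.env.step(action)


def process_env_fun(env_fun: Callable[[], ToyEnv]) -> Callable[[], RewardScaling]:
    # Return a new factory; the wrapped env is only built when it is called.
    def wrapped_env_fun() -> RewardScaling:
        return RewardScaling(env_fun(), scale=0.1)
    return wrapped_env_fun


make_env = process_env_fun(ToyEnv)
env = make_env()
print(env.step(2.0))
```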
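- abstract _get_loss(group: str, policy_for_loss: TensorDictModule, continuous: bool) → Tuple[LossModule, bool]
Implement this function to return the LossModule for a specific group.
- Parameters:
group (str) – agent group of the loss
policy_for_loss (TensorDictModule) – the policy to use in the loss
continuous (bool) – whether the policy is continuous or discrete
Returns: LossModule and a bool indicating whether the loss should have target parameters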
- abstract _get_parameters(group: str, loss: LossModule) → Dict[str, Iterable]
Get the dictionary mapping loss names to the corresponding parameters to optimize for a given group loss.
- Parameters:
group (str) – agent group of the loss
loss (LossModule) – the loss module of the group
Returns: a dictionary mapping loss names to a list of parameters
- abstract _get_policy_for_loss(group: str, model_config: ModelConfig, continuous: bool) → TensorDictModule
Get the non-explorative policy for a specific group.
- Parameters:
group (str) – agent group of the policy
model_config (ModelConfig) – model config class
continuous (bool) – whether the policy should be continuous or discrete
Returns: TensorDictModule representing the policy
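- abstract _get_policy_for_collection(policy_for_loss: TensorDictModule, group: str, continuous: bool) → TensorDictModule
Implement this function to add an explorative layer to the policy used in the loss.
- Parameters:
policy_for_loss (TensorDictModule) – the non-explorative policy used in the loss
group (str) – agent group of the policy
continuous (bool) – whether the policy is continuous or discrete
Returns: TensorDictModule representing the explorative policy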
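A common explorative layer for discrete actions is epsilon-greedy. It wraps the non-explorative policy and, with probability epsilon, replaces its choice with a uniformly random action. The sketch below is one possible layer, written in plain Python over a list of action values. The function names are hypothetical, and the algorithms in the library may use other exploration schemes. With epsilon = 0.2 and three actions, the greedy action should be chosen about 0.8 + 0.2/3 ≈ 87% of the time.

```python
import random
from typing import Callable, List

Policy = Callable[[List[float]], int]


def greedy_policy(q_values: List[float]) -> int:
    # Non-explorative policy: pick the highest-valued action.
    return max(range(len(q_values)), key=lambda a: q_values[a])


def add_epsilon_greedy(policy: Policy, epsilon: float, rng: random.Random) -> Policy:
    # Explorative layer: with probability epsilon take a uniformly random
    # action, otherwise defer to the wrapped policy.
    def explorative(q_values: List[float]) -> int:
        if rng.random() < epsilon:
            return rng.randrange(len(q_values))
        return policy(q_values)
    return explorative


rng = random.Random(0)
explore = add_epsilon_greedy(greedy_policy, epsilon=0.2, rng=rng)
actions = [explore([0.1, 0.9, 0.3]) for _ in range(1000)]
print(actions.count(1) / len(actions))
```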
- abstract process_batch(group: str, batch: TensorDictBase) → TensorDictBase
This function can be used to reshape data coming from collection before it is passed to the policy.
- Parameters:
group (str) – agent group
batch (TensorDictBase) – the batch of data coming from the collector
Returns: the processed batch
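A typical reshape is merging leading dimensions so that each element becomes an independent sample. The sketch below flattens hypothetical [time][agent] nested lists into flat lists. It only illustrates the kind of transformation this hook allows; what a given algorithm actually reshapes depends on its implementation.

```python
from typing import Dict, List

# Hypothetical batch: each key holds data shaped [time][agent].
Batch = Dict[str, List[List[float]]]


def process_batch(batch: Batch) -> Dict[str, List[float]]:
    # Merge the time and agent dimensions so every (time, agent) pair
    # becomes one independent sample.
    return {key: [x for row in rows for x in row] for key, rows in batch.items()}


batch = {
    "obs": [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]],  # 3 time steps, 2 agents
    "reward": [[1.0, 1.0], [0.0, 0.5], [0.2, 0.0]],
}
flat = process_batch(batch)
print(flat["obs"])  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```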
- process_loss_vals(group: str, loss_vals: TensorDictBase) → TensorDictBase
Here you can modify the loss_vals tensordict, which contains loss_name -> loss_value entries. For example, you can sum two entries into a new entry to optimize them together.
- Parameters:
group (str) – agent group
loss_vals (TensorDictBase) – the tensordict returned by the loss forward method
Returns: the processed loss_vals
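The summing example can be sketched with a plain dict standing in for the tensordict. The entry names loss_objective, loss_entropy and loss_critic are assumptions borrowed from common TorchRL loss outputs and may not match every algorithm. The sketch folds the entropy term into the objective and removes the standalone entry, so that term is not optimized twice.

```python
from typing import Dict


def process_loss_vals(group: str, loss_vals: Dict[str, float]) -> Dict[str, float]:
    # group is unused here; it mirrors the documented signature.
    processed = dict(loss_vals)  # work on a copy, leave the input untouched
    # Fold the entropy term into the objective and drop the standalone entry.
    processed["loss_objective"] = processed["loss_objective"] + processed.pop("loss_entropy")
    return processed


vals = {"loss_objective": 0.75, "loss_entropy": -0.25, "loss_critic": 1.5}
print(process_loss_vals("agents", vals))
# {'loss_objective': 0.5, 'loss_critic': 1.5}
```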