Source code for benchmarl.experiment.callback

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from __future__ import annotations

from typing import Any, Dict, List

from tensordict import TensorDictBase


class Callback:
    """
    A Callback that can be added to experiments.

    To create your callback, inherit from this class and reimplement just the
    hooks you need. Every hook defined here is a no-op that returns ``None``.

    Attributes:
        experiment (Experiment): the experiment associated to the callback.
            It is ``None`` until a ``CallbackNotifier`` is constructed with
            this callback, which assigns it.
    """

    def __init__(self):
        # Filled in by CallbackNotifier.__init__ when the callback is registered.
        self.experiment = None

    def on_setup(self):
        """A callback called at experiment setup."""

    def on_load_state_dict(self, state_dict: Dict[str, Any]):
        """
        A callback called at state_dict load.

        Args:
            state_dict (dict): the state dictionary handed to the callback
        """

    def on_batch_collected(self, batch: TensorDictBase):
        """
        A callback called at the end of every collection step.

        Args:
            batch (TensorDictBase): batch of collected data
        """

    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
        """
        A callback called for every training step.

        Args:
            batch (TensorDictBase): tensordict with the training batch
            group (str): group name

        Returns:
            TensorDictBase: a new tensordict containing the loss values, or
            ``None`` (what this default implementation returns) when the
            callback has nothing to report. ``CallbackNotifier`` skips
            ``None`` results when merging the outputs of several callbacks.
        """

    def on_train_end(self, training_td: TensorDictBase, group: str):
        """
        A callback called at the end of training.

        Args:
            training_td (TensorDictBase): tensordict containing the loss values
            group (str): group name
        """

    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """
        A callback called at the end of every evaluation.

        Args:
            rollouts (list of TensorDictBase): the tensordicts handed over by
                the notifier; presumably the rollouts gathered during the
                evaluation (confirm against the caller in the experiment)
        """

    def on_state_dict(self, state_dict: Dict[str, Any]):
        """
        A callback called at state_dict save.

        Args:
            state_dict (dict): the state dictionary handed to the callback
        """
class CallbackNotifier:
    """
    Fans each experiment event out to a list of callbacks.

    Every ``_on_*`` method invokes the matching ``on_*`` hook on each
    registered callback, in list order.
    """

    def __init__(self, experiment, callbacks: List[Callback]):
        """
        Store the callbacks and bind each of them to ``experiment``.

        Args:
            experiment: object assigned to the ``experiment`` attribute of
                every callback
            callbacks (list of Callback): callbacks to notify, in call order
        """
        self.callbacks = callbacks
        for cb in self.callbacks:
            cb.experiment = experiment

    def _on_setup(self):
        """Invoke ``on_setup`` on every callback."""
        for cb in self.callbacks:
            cb.on_setup()

    def _on_load_state_dict(self, state_dict: Dict[str, Any]):
        """Invoke ``on_load_state_dict`` on every callback with ``state_dict``."""
        for cb in self.callbacks:
            cb.on_load_state_dict(state_dict)

    def _on_batch_collected(self, batch: TensorDictBase):
        """Invoke ``on_batch_collected`` on every callback with ``batch``."""
        for cb in self.callbacks:
            cb.on_batch_collected(batch)

    def _on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
        """
        Invoke ``on_train_step`` on every callback and merge what they return.

        Args:
            batch (TensorDictBase): tensordict with the training batch
            group (str): group name

        Returns:
            TensorDictBase: the first non-``None`` result, updated in place
            (via its ``update`` method) with each later non-``None`` result,
            or ``None`` when no callback returned anything.
        """
        merged = None
        for cb in self.callbacks:
            result = cb.on_train_step(batch, group)
            if result is None:
                # This callback had nothing to contribute.
                continue
            if merged is None:
                # The first contribution becomes the accumulator itself.
                merged = result
            else:
                merged.update(result)
        return merged

    def _on_train_end(self, training_td: TensorDictBase, group: str):
        """Invoke ``on_train_end`` on every callback with ``training_td`` and ``group``."""
        for cb in self.callbacks:
            cb.on_train_end(training_td, group)

    def _on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """Invoke ``on_evaluation_end`` on every callback with ``rollouts``."""
        for cb in self.callbacks:
            cb.on_evaluation_end(rollouts)

    def _on_state_dict(self, state_dict: Dict[str, Any]):
        """Invoke ``on_state_dict`` on every callback with ``state_dict``."""
        for cb in self.callbacks:
            cb.on_state_dict(state_dict)