# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations

from typing import Any, Dict, List, Optional

from tensordict import TensorDictBase


class Callback:
    """
    A Callback that can be added to experiments.

    To create your callback, you can inherit from this class
    and reimplement just the functions you need. Every hook defined
    here is a no-op by default.

    Attributes:
        experiment (Experiment): the experiment associated to the callback.
            It is ``None`` until something (e.g. ``CallbackNotifier``)
            assigns it.
    """

    def __init__(self):
        # Filled in later by whoever registers the callback.
        self.experiment = None

    def on_setup(self):
        """A callback called at experiment setup."""

    def on_load_state_dict(self, state_dict: Dict[str, Any]):
        """
        A callback called at state_dict load.

        Args:
            state_dict (dict): the state dictionary being loaded
        """

    def on_batch_collected(self, batch: TensorDictBase):
        """
        A callback called at the end of every collection step.

        Args:
            batch (TensorDictBase): batch of collected data
        """

    def on_train_step(
        self, batch: TensorDictBase, group: str
    ) -> Optional[TensorDictBase]:
        """
        A callback called for every training step.

        Args:
            batch (TensorDictBase): tensordict with the training batch
            group (str): group name

        Returns:
            Optional[TensorDictBase]: a new tensordict containing the loss
            values, or ``None`` when the callback has nothing to report
            (this default implementation returns ``None``).
        """
        return None

    def on_train_end(self, training_td: TensorDictBase, group: str):
        """
        A callback called at the end of training.

        Args:
            training_td (TensorDictBase): tensordict containing the loss values
            group (str): group name
        """

    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """
        A callback called at the end of every evaluation step.

        Args:
            rollouts (list of TensorDictBase): the rollouts handed to the
                callback when evaluation finishes (presumably one tensordict
                per evaluation rollout -- confirm against the caller)
        """

    def on_state_dict(self, state_dict: Dict[str, Any]):
        """
        A callback called at state_dict save.

        Args:
            state_dict (dict): the state dictionary being saved
        """


class CallbackNotifier:
    """
    Dispatches experiment events to a list of callbacks.

    Each ``_on_*`` method forwards its arguments to the matching ``on_*``
    hook of every registered callback, in registration order.

    Args:
        experiment: the experiment that owns the callbacks; it is assigned
            to the ``experiment`` attribute of every callback
        callbacks (list of Callback): the callbacks to notify
    """

    def __init__(self, experiment, callbacks: List[Callback]):
        self.callbacks = callbacks
        # Give every callback a handle on the experiment it is attached to.
        for callback in self.callbacks:
            callback.experiment = experiment

    def _on_setup(self):
        """Notify every callback of experiment setup."""
        for callback in self.callbacks:
            callback.on_setup()

    def _on_load_state_dict(self, state_dict: Dict[str, Any]):
        """Notify every callback that ``state_dict`` is being loaded."""
        for callback in self.callbacks:
            callback.on_load_state_dict(state_dict)

    def _on_batch_collected(self, batch: TensorDictBase):
        """Notify every callback that ``batch`` has been collected."""
        for callback in self.callbacks:
            callback.on_batch_collected(batch)

    def _on_train_step(
        self, batch: TensorDictBase, group: str
    ) -> Optional[TensorDictBase]:
        """
        Notify every callback of a training step and merge their results.

        Args:
            batch (TensorDictBase): tensordict with the training batch
            group (str): group name

        Returns:
            Optional[TensorDictBase]: the merged results of all callbacks
            that returned something, or ``None`` if none did (including
            when there are no callbacks).
        """
        train_td = None
        for callback in self.callbacks:
            td = callback.on_train_step(batch, group)
            if td is None:
                # This callback has nothing to contribute.
                continue
            if train_td is None:
                # The first non-None result becomes the accumulator. It is
                # used as-is (not copied), so later results are merged into
                # that same object via ``update``.
                train_td = td
            else:
                train_td.update(td)
        return train_td

    def _on_train_end(self, training_td: TensorDictBase, group: str):
        """Notify every callback that training ended for ``group``."""
        for callback in self.callbacks:
            callback.on_train_end(training_td, group)

    def _on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """Notify every callback that evaluation ended, passing ``rollouts``."""
        for callback in self.callbacks:
            callback.on_evaluation_end(rollouts)

    def _on_state_dict(self, state_dict: Dict[str, Any]):
        """Notify every callback that ``state_dict`` is being saved."""
        for callback in self.callbacks:
            callback.on_state_dict(state_dict)