# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations
import copy
import importlib
import os
import pickle
import shutil
import time
import warnings
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
from tqdm import tqdm
from benchmarl.algorithms import IppoConfig, MappoConfig
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, TaskClass
from benchmarl.experiment.callback import Callback, CallbackNotifier
from benchmarl.experiment.logger import Logger
from benchmarl.models import GnnConfig, SequenceModelConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import (
_add_rnn_transforms,
_read_yaml_config,
local_seed,
seed_everything,
)
_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
from hydra.core.hydra_config import HydraConfig
[docs]
@dataclass
class ExperimentConfig:
"""
Configuration class for experiments.
This class acts as a schema for loading and validating yaml configurations.
Parameters in this class aim to be agnostic of the algorithm, task or model used.
To know their meaning, please check out the descriptions in ``benchmarl/conf/experiment/base_experiment.yaml``
"""
sampling_device: str = MISSING
train_device: str = MISSING
buffer_device: str = MISSING
share_policy_params: bool = MISSING
prefer_continuous_actions: bool = MISSING
collect_with_grad: bool = MISSING
parallel_collection: bool = MISSING
gamma: float = MISSING
lr: float = MISSING
adam_eps: float = MISSING
adam_extra_kwargs: Dict[str, Any] = MISSING
clip_grad_norm: bool = MISSING
clip_grad_val: Optional[float] = MISSING
soft_target_update: bool = MISSING
polyak_tau: float = MISSING
hard_target_update_frequency: int = MISSING
exploration_eps_init: float = MISSING
exploration_eps_end: float = MISSING
exploration_anneal_frames: Optional[int] = MISSING
max_n_iters: Optional[int] = MISSING
max_n_frames: Optional[int] = MISSING
on_policy_collected_frames_per_batch: int = MISSING
on_policy_n_envs_per_worker: int = MISSING
on_policy_n_minibatch_iters: int = MISSING
on_policy_minibatch_size: int = MISSING
off_policy_collected_frames_per_batch: int = MISSING
off_policy_n_envs_per_worker: int = MISSING
off_policy_n_optimizer_steps: int = MISSING
off_policy_train_batch_size: int = MISSING
off_policy_memory_size: int = MISSING
off_policy_init_random_frames: int = MISSING
off_policy_use_prioritized_replay_buffer: bool = MISSING
off_policy_prb_alpha: float = MISSING
off_policy_prb_beta: float = MISSING
evaluation: bool = MISSING
render: bool = MISSING
evaluation_interval: int = MISSING
evaluation_episodes: int = MISSING
evaluation_deterministic_actions: bool = MISSING
evaluation_static: bool = MISSING
loggers: List[str] = MISSING
project_name: str = MISSING
wandb_extra_kwargs: Dict[str, Any] = MISSING
create_json: bool = MISSING
save_folder: Optional[str] = MISSING
restore_file: Optional[str] = MISSING
restore_map_location: Optional[Any] = MISSING
checkpoint_interval: int = MISSING
checkpoint_at_end: bool = MISSING
keep_checkpoints_num: Optional[int] = MISSING
exclude_buffer_from_checkpoint: bool = MISSING
[docs]
def train_batch_size(self, on_policy: bool) -> int:
"""
The batch size of tensors used for training
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
self.collected_frames_per_batch(on_policy)
if on_policy
else self.off_policy_train_batch_size
)
[docs]
def train_minibatch_size(self, on_policy: bool) -> int:
"""
The minibatch size of tensors used for training.
On-policy algorithms are trained by splitting the train_batch_size (equal to the collected frames) into minibatches.
Off-policy algorithms do not go through this process and thus have the ``train_minibatch_size==train_batch_size``
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
self.on_policy_minibatch_size
if on_policy
else self.train_batch_size(on_policy)
)
[docs]
def n_optimizer_steps(self, on_policy: bool) -> int:
"""
Number of times to loop over the training step per collection iteration.
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
self.on_policy_n_minibatch_iters
if on_policy
else self.off_policy_n_optimizer_steps
)
[docs]
def replay_buffer_memory_size(self, on_policy: bool) -> int:
"""
Size of the replay buffer memory in terms of frames
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
self.collected_frames_per_batch(on_policy)
if on_policy
else self.off_policy_memory_size
)
[docs]
def collected_frames_per_batch(self, on_policy: bool) -> int:
"""
Number of collected frames per collection iteration.
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
self.on_policy_collected_frames_per_batch
if on_policy
else self.off_policy_collected_frames_per_batch
)
[docs]
def n_envs_per_worker(self, on_policy: bool) -> int:
"""
Number of environments used for collection
- In vectorized environments, this will be the vectorized batch_size.
- In other environments, this will be emulated by running them sequentially.
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
self.on_policy_n_envs_per_worker
if on_policy
else self.off_policy_n_envs_per_worker
)
[docs]
def get_max_n_frames(self, on_policy: bool) -> int:
"""
Get the maximum number of frames collected before the experiment ends.
Args:
on_policy (bool): is the algorithms on_policy
"""
if self.max_n_frames is not None and self.max_n_iters is not None:
return min(
self.max_n_frames,
self.max_n_iters * self.collected_frames_per_batch(on_policy),
)
elif self.max_n_frames is not None:
return self.max_n_frames
elif self.max_n_iters is not None:
return self.max_n_iters * self.collected_frames_per_batch(on_policy)
[docs]
def get_max_n_iters(self, on_policy: bool) -> int:
"""
Get the maximum number of experiment iterations before the experiment ends.
Args:
on_policy (bool): is the algorithms on_policy
"""
return -(
-self.get_max_n_frames(on_policy)
// self.collected_frames_per_batch(on_policy)
)
[docs]
def get_exploration_anneal_frames(self, on_policy: bool):
"""
Get the number of frames for exploration annealing.
If self.exploration_anneal_frames is None this will be a third of the total frames to collect.
Args:
on_policy (bool): is the algorithms on_policy
"""
return (
(self.get_max_n_frames(on_policy) // 3)
if self.exploration_anneal_frames is None
else self.exploration_anneal_frames
)
[docs]
@staticmethod
def get_from_yaml(path: Optional[str] = None):
"""
Load the experiment configuration from yaml
Args:
path (str, optional): The full path of the yaml file to load from.
If None, it will default to
``benchmarl/conf/experiment/base_experiment.yaml``
Returns:
the loaded :class:`~benchmarl.experiment.ExperimentConfig`
"""
if path is None:
yaml_path = (
Path(__file__).parent.parent
/ "conf"
/ "experiment"
/ "base_experiment.yaml"
)
return ExperimentConfig(**_read_yaml_config(str(yaml_path.resolve())))
else:
return ExperimentConfig(**_read_yaml_config(path))
[docs]
def validate(self, on_policy: bool):
"""
Validates config.
Args:
on_policy (bool): is the algorithms on_policy
"""
if (
self.evaluation
and self.evaluation_interval % self.collected_frames_per_batch(on_policy)
!= 0
):
raise ValueError(
f"evaluation_interval ({self.evaluation_interval}) "
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
)
if (
self.checkpoint_interval != 0
and self.checkpoint_interval % self.collected_frames_per_batch(on_policy)
!= 0
):
raise ValueError(
f"checkpoint_interval ({self.checkpoint_interval}) "
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
)
if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
raise ValueError("keep_checkpoints_num must be greater than zero or null")
if self.max_n_frames is None and self.max_n_iters is None:
raise ValueError("max_n_frames and max_n_iters are both not set")
if self.max_n_frames is not None and self.max_n_iters is not None:
warnings.warn(
f"max_n_frames and max_n_iters have both been set. The experiment will terminate after "
f"{self.get_max_n_iters(on_policy)} iterations ({self.get_max_n_frames(on_policy)} frames)."
)
[docs]
class Experiment(CallbackNotifier):
"""
Main experiment class in BenchMARL.
Args:
task (TaskClass): the task
algorithm_config (AlgorithmConfig): the algorithm configuration
model_config (ModelConfig): the policy model configuration
seed (int): the seed for the experiment
config (ExperimentConfig): The experiment config. Note that some of the parameters
of this config may go un-consumed based on the provided algorithm or model config.
For example, all parameters off-policy algorithm would not be used when running
an experiment with an on-policy algorithm.
critic_model_config (ModelConfig, optional): the policy model configuration.
If None, it defaults to model_config
callbacks (list of Callback, optional): callbacks for this experiment
"""
def __init__(
self,
task: Union[Task, TaskClass],
algorithm_config: AlgorithmConfig,
model_config: ModelConfig,
seed: int,
config: ExperimentConfig,
critic_model_config: Optional[ModelConfig] = None,
callbacks: Optional[List[Callback]] = None,
):
super().__init__(
experiment=self, callbacks=callbacks if callbacks is not None else []
)
self.config = config
if isinstance(task, Task):
warnings.warn(
"Call `.get_task()` or `.get_from_yaml()` on your task Enum before passing it to the experiment. "
"If you do not do this, benchmarl will load the default task config from yaml."
)
task = task.get_task()
self.task = task
self.model_config = model_config
self.critic_model_config = (
critic_model_config
if critic_model_config is not None
else copy.deepcopy(model_config)
)
self.critic_model_config.is_critic = True
self.algorithm_config = algorithm_config
self.seed = seed
self._setup()
self.total_time = 0
self.total_frames = 0
self.n_iters_performed = 0
self.mean_return = 0
if self.config.restore_file is not None:
self._load_experiment()
@property
def on_policy(self) -> bool:
"""Whether the algorithm has to be run on policy."""
return self.algorithm_config.on_policy()
def _setup(self):
self.config.validate(self.on_policy)
seed_everything(self.seed)
self._perform_checks()
self._set_action_type()
self._setup_name()
self._setup_task()
self._setup_algorithm()
self._setup_collector()
self._setup_logger()
self._on_setup()
def _perform_checks(self):
for config in (self.model_config, self.critic_model_config):
if isinstance(config, SequenceModelConfig):
for layer_config in config.model_configs[1:]:
if isinstance(layer_config, GnnConfig) and (
layer_config.position_key is not None
or layer_config.velocity_key is not None
):
raise ValueError(
"GNNs reading position or velocity keys are currently only usable in first"
" layer of sequence models"
)
if self.algorithm_config in (MappoConfig, IppoConfig):
critic_model_config = self.critic_model_config
if isinstance(critic_model_config, SequenceModelConfig):
critic_model_config = self.critic_model_config.model_configs[0]
if (
isinstance(critic_model_config, GnnConfig)
and critic_model_config.topology == "from_pos"
):
raise ValueError(
"GNNs in PPO critics with topology 'from_pos' are currently not available, "
"see https://github.com/pytorch/rl/issues/2537"
)
def _set_action_type(self):
if (
self.task.supports_continuous_actions()
and self.algorithm_config.supports_continuous_actions()
and self.config.prefer_continuous_actions
):
self.continuous_actions = True
elif (
self.task.supports_discrete_actions()
and self.algorithm_config.supports_discrete_actions()
):
self.continuous_actions = False
elif (
self.task.supports_continuous_actions()
and self.algorithm_config.supports_continuous_actions()
):
self.continuous_actions = True
else:
raise ValueError(
f"Algorithm {self.algorithm_config} is not compatible"
f" with the action space of task {self.task} "
)
def _setup_task(self):
test_env = self.task.get_env_fun(
num_envs=self.config.evaluation_episodes,
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
)()
env_func = self.task.get_env_fun(
num_envs=self.config.n_envs_per_worker(self.on_policy),
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
)
transforms_env = self.task.get_env_transforms(test_env)
transforms_training = transforms_env + [
self.task.get_reward_sum_transform(test_env)
]
transforms_env = Compose(*transforms_env)
transforms_training = Compose(*transforms_training)
# Initialize test env
self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
self.config.sampling_device
)
self.observation_spec = self.task.observation_spec(self.test_env)
self.info_spec = self.task.info_spec(self.test_env)
self.state_spec = self.task.state_spec(self.test_env)
self.action_mask_spec = self.task.action_mask_spec(self.test_env)
self.action_spec = self.task.action_spec(self.test_env)
self.group_map = self.task.group_map(self.test_env)
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(self.test_env)
# Add rnn transforms here so they do not show in the benchmarl specs
if self.model_config.is_rnn:
self.test_env = _add_rnn_transforms(
lambda: self.test_env, self.group_map, self.model_config
)()
env_func = _add_rnn_transforms(env_func, self.group_map, self.model_config)
# Initialize train env
if self.test_env.batch_size == ():
# If the environment is not vectorized, we simulate vectorization using parallel or serial environments
env_class = (
SerialEnv if not self.config.parallel_collection else ParallelEnv
)
self.env_func = lambda: TransformedEnv(
env_class(self.config.n_envs_per_worker(self.on_policy), env_func),
transforms_training.clone(),
)
else:
# Otherwise it is already vectorized
self.env_func = lambda: TransformedEnv(
env_func(), transforms_training.clone()
)
def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)
self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
self.env_func = self.algorithm.process_env_fun(self.env_func)
self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
)
for group in self.group_map.keys()
}
self.losses = {
group: self.algorithm.get_loss_and_updater(group)[0]
for group in self.group_map.keys()
}
self.target_updaters = {
group: self.algorithm.get_loss_and_updater(group)[1]
for group in self.group_map.keys()
}
self.optimizers = {
group: {
loss_name: torch.optim.Adam(
params,
lr=self.config.lr,
eps=self.config.adam_eps,
**self.config.adam_extra_kwargs,
)
for loss_name, params in self.algorithm.get_parameters(group).items()
}
for group in self.group_map.keys()
}
def _setup_collector(self):
self.policy = self.algorithm.get_policy_for_collection()
self.group_policies = {}
for group in self.group_map.keys():
group_policy = self.policy.select_subsequence(out_keys=[(group, "action")])
assert len(group_policy) == 1
self.group_policies.update({group: group_policy[0]})
if not self.config.collect_with_grad:
self.collector = SyncDataCollector(
self.env_func,
self.policy,
device=self.config.sampling_device,
storing_device=self.config.sampling_device,
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
total_frames=self.config.get_max_n_frames(self.on_policy),
init_random_frames=(
self.config.off_policy_init_random_frames
if not self.on_policy
else 0
),
)
else:
if self.config.off_policy_init_random_frames and not self.on_policy:
raise TypeError(
"Collection via rollouts does not support initial random frames as of now."
)
self.rollout_env = self.env_func().to(self.config.sampling_device)
def _setup_name(self):
self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
self.model_name = self.model_config.associated_class().__name__.lower()
self.critic_model_name = (
self.critic_model_config.associated_class().__name__.lower()
)
self.environment_name = self.task.env_name().lower()
self.task_name = self.task.name.lower()
self._checkpointed_files = deque([])
if self.config.save_folder is not None:
# If the user specified a folder for the experiment we use that
save_folder = Path(self.config.save_folder)
else:
# Otherwise, if the user is restoring from a folder, we will save in the folder they are restoring from
if self.config.restore_file is not None:
save_folder = Path(
self.config.restore_file
).parent.parent.parent.resolve()
# Otherwise, the user is not restoring and did not specify a save_folder so we save in the hydra directory
# of the experiment or in the directory where the experiment was run (if hydra is not used)
else:
if _has_hydra and HydraConfig.initialized():
save_folder = Path(HydraConfig.get().runtime.output_dir)
else:
save_folder = Path(os.getcwd())
if self.config.restore_file is None:
self.name = generate_exp_name(
f"{self.algorithm_name}_{self.task_name}_{self.model_name}", ""
)
self.folder_name = save_folder / self.name
else:
# If restoring, we use the name of the previous experiment
self.name = Path(self.config.restore_file).parent.parent.resolve().name
self.folder_name = save_folder / self.name
self.folder_name.mkdir(parents=False, exist_ok=True)
with open(self.folder_name / "config.pkl", "wb") as f:
pickle.dump(self.task, f)
pickle.dump(self.task.config if self.task.config is not None else {}, f)
pickle.dump(self.algorithm_config, f)
pickle.dump(self.model_config, f)
pickle.dump(self.seed, f)
pickle.dump(self.config, f)
pickle.dump(self.critic_model_config, f)
pickle.dump(self.callbacks, f)
def _setup_logger(self):
hparams_kwargs = {
"critic_model_name": self.critic_model_name,
"experiment_config": self.config.__dict__,
"algorithm_config": self.algorithm_config.__dict__,
"model_config": self.model_config.__dict__,
"critic_model_config": self.critic_model_config.__dict__,
"task_config": self.task.config,
"continuous_actions": self.continuous_actions,
"on_policy": self.on_policy,
"algorithm_name": self.algorithm_name,
"model_name": self.model_name,
"task_name": self.task_name,
"environment_name": self.environment_name,
"seed": self.seed,
}
self.logger = Logger(
experiment_name=self.name,
folder_name=str(self.folder_name),
experiment_config=self.config,
algorithm_name=self.algorithm_name,
model_name=self.model_name,
environment_name=self.environment_name,
task_name=self.task_name,
group_map=self.group_map,
seed=self.seed,
project_name=self.config.project_name,
wandb_extra_kwargs={
**self.config.wandb_extra_kwargs,
"config": hparams_kwargs,
},
)
self.logger.log_hparams(**hparams_kwargs)
[docs]
def run(self):
"""Run the experiment until completion."""
try:
seed_everything(self.seed)
torch.cuda.empty_cache()
self._collection_loop()
except KeyboardInterrupt as interrupt:
print("\n\nExperiment was closed gracefully\n\n")
self.close()
raise interrupt
except Exception as err:
print("\n\nExperiment failed and is closing gracefully\n\n")
self.close()
raise err
[docs]
def evaluate(self):
"""Run just the evaluation loop once."""
seed_everything(self.seed)
self._evaluation_loop()
self.logger.commit()
print(
f"Evaluation results logged to loggers={self.config.loggers}"
f"{' and to a json file in the experiment folder.' if self.config.create_json else ''}"
)
def _collection_loop(self):
pbar = tqdm(
initial=self.n_iters_performed,
total=self.config.get_max_n_iters(self.on_policy),
)
if not self.config.collect_with_grad:
iterator = iter(self.collector)
else:
reset_batch = self.rollout_env.reset()
# Training/collection iterations
for _ in range(
self.n_iters_performed, self.config.get_max_n_iters(self.on_policy)
):
iteration_start = time.time()
if not self.config.collect_with_grad:
batch = next(iterator)
else:
with set_exploration_type(ExplorationType.RANDOM):
batch = self.rollout_env.rollout(
max_steps=-(
-self.config.collected_frames_per_batch(self.on_policy)
// self.rollout_env.batch_size.numel()
),
policy=self.policy,
auto_cast_to_device=True,
break_when_any_done=False,
auto_reset=False,
tensordict=reset_batch,
)
reset_batch = step_mdp(
batch[..., -1],
reward_keys=self.rollout_env.reward_keys,
action_keys=self.rollout_env.action_keys,
done_keys=self.rollout_env.done_keys,
)
# Logging collection
collection_time = time.time() - iteration_start
current_frames = batch.numel()
self.total_frames += current_frames
self.mean_return = self.logger.log_collection(
batch,
total_frames=self.total_frames,
task=self.task,
step=self.n_iters_performed,
)
pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
# Callback
self._on_batch_collected(batch)
batch = batch.detach()
# Loop over groups
training_start = time.time()
for group in self.train_group_map.keys():
group_batch = batch.exclude(*self._get_excluded_keys(group)).to(
self.config.train_device
)
group_batch = self.algorithm.process_batch(group, group_batch)
if not self.algorithm.has_rnn:
group_batch = group_batch.reshape(-1)
group_buffer = self.replay_buffers[group]
group_buffer.extend(group_batch.to(group_buffer.storage.device))
training_tds = []
for _ in range(self.config.n_optimizer_steps(self.on_policy)):
for _ in range(
-(
-self.config.train_batch_size(self.on_policy)
// self.config.train_minibatch_size(self.on_policy)
)
):
training_tds.append(self._optimizer_loop(group))
training_td = torch.stack(training_tds)
self.logger.log_training(
group, training_td, step=self.n_iters_performed
)
# Callback
self._on_train_end(training_td, group)
# Exploration update
if isinstance(self.group_policies[group], TensorDictSequential):
explore_layer = self.group_policies[group][-1]
else:
explore_layer = self.group_policies[group]
if hasattr(explore_layer, "step"): # Step exploration annealing
explore_layer.step(current_frames)
# Update policy in collector
if not self.config.collect_with_grad:
self.collector.update_policy_weights_()
# Training timer
training_time = time.time() - training_start
# Evaluation
if (
self.config.evaluation
and (
self.total_frames % self.config.evaluation_interval == 0
or self.n_iters_performed == 0
)
and (len(self.config.loggers) or self.config.create_json)
):
self._evaluation_loop()
# End of step
iteration_time = time.time() - iteration_start
self.total_time += iteration_time
self.logger.log(
{
"timers/collection_time": collection_time,
"timers/training_time": training_time,
"timers/iteration_time": iteration_time,
"timers/total_time": self.total_time,
"counters/current_frames": current_frames,
"counters/total_frames": self.total_frames,
"counters/iter": self.n_iters_performed,
},
step=self.n_iters_performed,
)
self.n_iters_performed += 1
self.logger.commit()
if (
self.config.checkpoint_interval > 0
and self.total_frames % self.config.checkpoint_interval == 0
):
self._save_experiment()
pbar.update()
if self.config.checkpoint_at_end:
self._save_experiment()
self.close()
[docs]
def close(self):
"""Close the experiment."""
if not self.config.collect_with_grad:
self.collector.shutdown()
else:
self.rollout_env.close()
self.test_env.close()
self.logger.finish()
for buffer in self.replay_buffers.values():
if hasattr(buffer.storage, "scratch_dir"):
shutil.rmtree(buffer.storage.scratch_dir, ignore_errors=False)
def _get_excluded_keys(self, group: str):
excluded_keys = []
for other_group in self.group_map.keys():
if other_group != group:
excluded_keys += [other_group, ("next", other_group)]
excluded_keys += ["info", (group, "info"), ("next", group, "info")]
return excluded_keys
def _optimizer_loop(self, group: str) -> TensorDictBase:
subdata = self.replay_buffers[group].sample().to(self.config.train_device)
loss_vals = self.losses[group](subdata)
training_td = loss_vals.detach()
loss_vals = self.algorithm.process_loss_vals(group, loss_vals)
for loss_name, loss_value in loss_vals.items():
if loss_name in self.optimizers[group].keys():
optimizer = self.optimizers[group][loss_name]
loss_value.backward()
grad_norm = self._grad_clip(optimizer)
training_td.set(
f"grad_norm_{loss_name}",
torch.tensor(grad_norm, device=self.config.train_device),
)
optimizer.step()
optimizer.zero_grad()
self.replay_buffers[group].update_tensordict_priority(subdata)
if self.target_updaters[group] is not None:
self.target_updaters[group].step()
callback_loss = self._on_train_step(subdata, group)
if callback_loss is not None:
training_td.update(callback_loss)
return training_td
def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
params = []
for param_group in optimizer.param_groups:
params += param_group["params"]
if self.config.clip_grad_norm and self.config.clip_grad_val is not None:
total_norm = torch.nn.utils.clip_grad_norm_(
params, self.config.clip_grad_val
)
else:
norm_type = 2.0
norms = [
torch.linalg.vector_norm(p.grad, norm_type)
for p in params
if p.grad is not None
]
total_norm = torch.linalg.vector_norm(torch.stack(norms), norm_type)
if self.config.clip_grad_val is not None:
torch.nn.utils.clip_grad_value_(params, self.config.clip_grad_val)
return float(total_norm)
@local_seed()
@torch.no_grad()
def _evaluation_loop(self):
if self.config.evaluation_static:
seed_everything(self.seed)
try:
self.test_env.set_seed(self.seed)
except NotImplementedError:
warnings.warn(
"`experiment.evaluation_static` set to true but the environment does not allow to set seeds."
"Static evaluation is not guaranteed."
)
evaluation_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
if self.config.evaluation_deterministic_actions
else ExplorationType.RANDOM
):
if self.task.has_render(self.test_env) and self.config.render:
video_frames = []
def callback(env, td):
video_frames.append(
self.task.__class__.render_callback(self, env, td)
)
else:
video_frames = None
callback = None
if self.test_env.batch_size == ():
rollouts = []
for eval_episode in range(self.config.evaluation_episodes):
rollouts.append(
self.test_env.rollout(
max_steps=self.max_steps,
policy=self.policy,
callback=callback if eval_episode == 0 else None,
auto_cast_to_device=True,
break_when_any_done=True,
)
)
else:
rollouts = self.test_env.rollout(
max_steps=self.max_steps,
policy=self.policy,
callback=callback,
auto_cast_to_device=True,
break_when_any_done=False,
# We are running vectorized evaluation we do not want it to stop when just one env is done
)
rollouts = list(rollouts.unbind(0))
evaluation_time = time.time() - evaluation_start
self.logger.log(
{"timers/evaluation_time": evaluation_time}, step=self.n_iters_performed
)
self.logger.log_evaluation(
rollouts,
video_frames=video_frames,
step=self.n_iters_performed,
total_frames=self.total_frames,
)
# Callback
self._on_evaluation_end(rollouts)
# Saving experiment state
[docs]
def state_dict(self) -> OrderedDict:
"""Get the state_dict for the experiment."""
state = OrderedDict(
total_time=self.total_time,
total_frames=self.total_frames,
n_iters_performed=self.n_iters_performed,
mean_return=self.mean_return,
)
state_dict = OrderedDict(
state=state,
**{f"loss_{k}": item.state_dict() for k, item in self.losses.items()},
**{
f"buffer_{k}": item.state_dict()
if len(item) and not self.config.exclude_buffer_from_checkpoint
else None
for k, item in self.replay_buffers.items()
},
)
if not self.config.collect_with_grad:
state_dict.update({"collector": self.collector.state_dict()})
self._on_state_dict(state_dict)
return state_dict
[docs]
def load_state_dict(self, state_dict: Dict) -> None:
"""Load the state_dict for the experiment.
Args:
state_dict (dict): the state dict
"""
for group in self.group_map.keys():
self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
if state_dict[f"buffer_{group}"] is not None:
self.replay_buffers[group].load_state_dict(
state_dict[f"buffer_{group}"]
)
if not self.config.collect_with_grad:
self.collector.load_state_dict(state_dict["collector"])
self.total_time = state_dict["state"]["total_time"]
self.total_frames = state_dict["state"]["total_frames"]
self.n_iters_performed = state_dict["state"]["n_iters_performed"]
self.mean_return = state_dict["state"]["mean_return"]
self._on_load_state_dict(state_dict)
def _save_experiment(self) -> None:
"""Checkpoint trainer"""
if self.config.keep_checkpoints_num is not None:
while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
file_to_delete = self._checkpointed_files.popleft()
file_to_delete.unlink(missing_ok=False)
checkpoint_folder = self.folder_name / "checkpoints"
checkpoint_folder.mkdir(parents=False, exist_ok=True)
checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
torch.save(self.state_dict(), checkpoint_file)
self._checkpointed_files.append(checkpoint_file)
def _load_experiment(self) -> Experiment:
"""Load trainer from checkpoint"""
loaded_dict: OrderedDict = torch.load(
self.config.restore_file, map_location=self.config.restore_map_location
)
self.load_state_dict(loaded_dict)
return self
[docs]
@staticmethod
def reload_from_file(
restore_file: str, experiment_patch: Optional[Dict[str, Any]] = None
) -> Experiment:
"""
Restores the experiment from the checkpoint file.
This method expects the same folder structure created when an experiment is run.
The checkpoint file (``restore_file``) is in the checkpoints directory and a config.pkl file is
present a level above at restore_file/../../config.pkl
Args:
restore_file (str): The checkpoint file (.pt) of the experiment reload.
experiment_patch (Optional[Dict[str, Any]]): The patch to apply to the experiment config.
Returns:
The reloaded experiment.
"""
experiment_folder = Path(restore_file).parent.parent.resolve()
config_file = experiment_folder / "config.pkl"
if not os.path.exists(config_file):
raise ValueError("config.pkl file not found in experiment folder.")
with open(config_file, "rb") as f:
task = pickle.load(f)
task_config = pickle.load(f)
algorithm_config = pickle.load(f)
model_config = pickle.load(f)
seed = pickle.load(f)
experiment_config = pickle.load(f)
critic_model_config = pickle.load(f)
callbacks = pickle.load(f)
task.config = task_config
experiment_config.restore_file = restore_file
if experiment_patch is not None:
for key, value in experiment_patch.items():
if not hasattr(experiment_config, key):
raise ValueError(f"Experiment config does not have attribute {key}")
setattr(experiment_config, key, value)
experiment = Experiment(
task=task,
algorithm_config=algorithm_config,
model_config=model_config,
seed=seed,
config=experiment_config,
callbacks=callbacks,
critic_model_config=critic_model_config,
)
print(f"\nReloaded experiment {experiment.name} from {restore_file}.")
return experiment