# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import importlib
from dataclasses import is_dataclass
from pathlib import Path
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import task_config_registry, TaskClass
from benchmarl.environments.common import _type_check_task_config
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import model_config_registry
from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig
_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
from hydra import compose, initialize, initialize_config_dir
from omegaconf import DictConfig, OmegaConf
class _HydraMissingMetadataError(FileNotFoundError):
def __init__(
self,
message=".hydra folder not found (should be max 3 levels above checkpoint file",
):
self.message = message
super().__init__(self.message)
[docs]
def load_experiment_from_hydra(
cfg: DictConfig, task_name: str, callbacks=()
) -> Experiment:
"""Creates an :class:`~benchmarl.experiment.Experiment` from hydra config.
Args:
cfg (DictConfig): the config dictionary from hydra main
task_name (str): the name of the task to load
Returns:
:class:`~benchmarl.experiment.Experiment`
"""
algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
experiment_config = load_experiment_config_from_hydra(cfg.experiment)
task_config = load_task_config_from_hydra(cfg.task, task_name)
model_config = load_model_config_from_hydra(cfg.model)
critic_model_config = load_model_config_from_hydra(cfg.critic_model)
return Experiment(
task=task_config,
algorithm_config=algorithm_config,
model_config=model_config,
critic_model_config=critic_model_config,
seed=cfg.seed,
config=experiment_config,
callbacks=callbacks,
)
[docs]
def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> TaskClass:
"""Returns a :class:`~benchmarl.environments.Task` from hydra config.
Args:
cfg (DictConfig): the task config dictionary from hydra
task_name (str): the name of the task to load
Returns:
:class:`~benchmarl.environments.Task`
"""
environment_name, inner_task_name = task_name.split("/")
cfg_dict_checked = OmegaConf.to_object(cfg)
if is_dataclass(cfg_dict_checked):
cfg_dict_checked = cfg_dict_checked.__dict__
cfg_dict_checked = _type_check_task_config(
environment_name, inner_task_name, cfg_dict_checked
) # Only needed for the warning
return task_config_registry[task_name].get_task(cfg_dict_checked)
[docs]
def load_experiment_config_from_hydra(cfg: DictConfig) -> ExperimentConfig:
"""Returns a :class:`~benchmarl.experiment.ExperimentConfig` from hydra config.
Args:
cfg (DictConfig): the experiment config dictionary from hydra
Returns:
:class:`~benchmarl.experiment.ExperimentConfig`
"""
return OmegaConf.to_object(cfg)
[docs]
def load_algorithm_config_from_hydra(cfg: DictConfig) -> AlgorithmConfig:
"""Returns a :class:`~benchmarl.algorithms.AlgorithmConfig` from hydra config.
Args:
cfg (DictConfig): the algorithm config dictionary from hydra
Returns:
:class:`~benchmarl.algorithms.AlgorithmConfig`
"""
return OmegaConf.to_object(cfg)
[docs]
def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig:
"""Returns a :class:`~benchmarl.models.ModelConfig` from hydra config.
Args:
cfg (DictConfig): the model config dictionary from hydra
Returns:
:class:`~benchmarl.models.ModelConfig`
"""
if "layers" in cfg.keys():
model_configs = [
load_model_config_from_hydra(cfg.layers[f"l{i}"])
for i in range(1, len(cfg.layers) + 1)
]
return SequenceModelConfig(
model_configs=model_configs, intermediate_sizes=cfg.intermediate_sizes
)
else:
model_class = model_config_registry[cfg.name]
return model_class(
**parse_model_config(
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)
)
def _find_hydra_folder(restore_file: str) -> str:
"""Given the restore file, look for the .hydra folder max three levels above it."""
current_folder = Path(restore_file).parent.resolve()
for _ in range(3):
hydra_dir = current_folder / ".hydra"
if hydra_dir.exists() and hydra_dir.is_dir():
return str(hydra_dir)
current_folder = current_folder.parent
raise _HydraMissingMetadataError()
[docs]
def reload_experiment_from_file(restore_file: str) -> Experiment:
"""Reloads the experiment from a given restore file.
Requires a ``.hydra`` folder containing ``config.yaml``, ``hydra.yaml``, and ``overrides.yaml``
at max three directory levels higher than the checkpoint file. This should be automatically created by hydra.
Args:
restore_file (str): The checkpoint file of the experiment reload.
"""
try:
hydra_folder = _find_hydra_folder(restore_file)
except _HydraMissingMetadataError:
# Hydra was not used
return Experiment.reload_from_file(restore_file)
with initialize(
version_base=None,
config_path="conf",
):
cfg = compose(
config_name="config",
overrides=OmegaConf.load(Path(hydra_folder) / "overrides.yaml"),
return_hydra_config=True,
)
task_name = cfg.hydra.runtime.choices.task
algorithm_name = cfg.hydra.runtime.choices.algorithm
with initialize_config_dir(version_base=None, config_dir=hydra_folder):
cfg_loaded = dict(compose(config_name="config"))
for key in ("experiment", "algorithm", "task", "model", "critic_model"):
cfg[key].update(cfg_loaded[key])
cfg_loaded.pop(key)
cfg.update(cfg_loaded)
del cfg.hydra
cfg.experiment.restore_file = restore_file
print("\nReloaded experiment with:")
print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))
return load_experiment_from_hydra(cfg, task_name=task_name)