Source code for benchmarl.benchmark.benchmark

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from typing import Iterator, Optional, Sequence, Set

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.common import ModelConfig


class Benchmark:
    """A benchmark.

    Benchmarks are collections of experiments to compare. One experiment is
    created for every (algorithm, task, seed) combination.

    Args:
        algorithm_configs (list of AlgorithmConfig): the algorithms to benchmark
        model_config (ModelConfig): the config of the policy model
        tasks (list of Task): the tasks to benchmark
        seeds (set of int): the seeds for the benchmark
        experiment_config (ExperimentConfig): the experiment config
        critic_model_config (ModelConfig, optional): the config of the critic model.
            Defaults to model_config

    """

    def __init__(
        self,
        algorithm_configs: Sequence[AlgorithmConfig],
        model_config: ModelConfig,
        tasks: Sequence[Task],
        seeds: Set[int],
        experiment_config: ExperimentConfig,
        critic_model_config: Optional[ModelConfig] = None,
    ) -> None:
        self.algorithm_configs = algorithm_configs
        self.tasks = tasks
        self.seeds = seeds

        self.model_config = model_config
        # The critic shares the policy model config unless one is given explicitly.
        self.critic_model_config = (
            critic_model_config if critic_model_config is not None else model_config
        )
        self.experiment_config = experiment_config

        print(f"Created benchmark with {self.n_experiments} experiments.")

    @property
    def n_experiments(self) -> int:
        """The number of experiments in the benchmark.

        This is the product of the number of algorithms, tasks and seeds.
        """
        return len(self.algorithm_configs) * len(self.tasks) * len(self.seeds)

    def get_experiments(self) -> Iterator[Experiment]:
        """Yields one experiment at a time.

        Experiments are built lazily, iterating over algorithms first, then
        tasks, then seeds. All of them share the benchmark's model, critic
        model and experiment configs.
        """
        for algorithm_config in self.algorithm_configs:
            for task in self.tasks:
                for seed in self.seeds:
                    yield Experiment(
                        task=task,
                        algorithm_config=algorithm_config,
                        seed=seed,
                        model_config=self.model_config,
                        critic_model_config=self.critic_model_config,
                        config=self.experiment_config,
                    )

    def run_sequential(self) -> None:
        """Run all the experiments in the benchmark in a sequence.

        Raises:
            KeyboardInterrupt: re-raised after closing the experiment that was
                running when the interrupt arrived.
        """
        # The total does not change while running, so compute it once.
        total = self.n_experiments
        for index, experiment in enumerate(self.get_experiments(), start=1):
            print(f"\nRunning experiment {index}/{total}.\n")
            try:
                experiment.run()
            except KeyboardInterrupt:
                print("\n\nBenchmark was closed gracefully\n\n")
                experiment.close()
                # Propagate the interrupt so the remaining experiments are skipped.
                raise