"""BayBE settings."""
from __future__ import annotations
import gc
import os
import random
import tempfile
import warnings
from copy import deepcopy
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
import numpy as np
from attrs import Attribute, Converter, Factory, define, field, fields
from attrs.setters import validate
from attrs.validators import in_, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import Self
from baybe._optional.info import FPSAMPLE_INSTALLED, POLARS_INSTALLED
from baybe.exceptions import OptionalImportError
from baybe.utils.basic import classproperty
from baybe.utils.boolean import AutoBool, strtobool
if TYPE_CHECKING:
from types import TracebackType
import torch
# Type variable used by the random seed setter hook so that its return type
# mirrors its input type (an ``int`` seed or ``None``).
_TSeed = TypeVar("_TSeed", int, None)

# The temporary assignment to `None` is needed because the object is already referenced
# in the `Settings` class body. The name is rebound to the actual global `Settings`
# instance at the bottom of this module.
active_settings: Settings = None  # type: ignore[assignment]
"""The global settings instance controlling execution behavior."""

# Error message template used by the `use_*` validators of `Settings` when an
# optional package is explicitly requested but not installed. The only placeholder
# is `package_name`.
_MISSING_PACKAGE_ERROR_MESSAGE = (
    "The setting 'use_{package_name}' cannot be set to 'True' because '{package_name}' "
    "is not installed. Either install '{package_name}' or set 'use_{package_name}' "
    "to 'False'/'Auto'."
)
def _validate_whitelist_env_vars(vars: dict[str, str], /) -> None:
"""Validate the values of non-settings environment variables."""
if (value := vars.pop("BAYBE_TEST_ENV", None)) is not None:
if value not in {"CORETEST", "FULLTEST", "GPUTEST"}:
raise ValueError(
f"Allowed values for 'BAYBE_TEST_ENV' are "
f"'CORETEST', 'FULLTEST', and 'GPUTEST'. Given: '{value}'"
)
if vars:
raise RuntimeError(f"Unknown environment variables: {set(vars)}")
class _SlottedContextDecorator:
"""Like :class:`contextlib.ContextDecorator` but with `__slots__`.
The code has been copied from the Python standard library.
"""
__slots__ = ()
def _recreate_cm(self):
return self
def __call__(self, func):
@wraps(func)
def inner(*args, **kwds):
with self._recreate_cm():
return func(*args, **kwds)
return inner
def _to_bool(value: Any) -> bool:
"""Convert Booleans and strings representing Booleans to actual Booleans."""
if isinstance(value, bool):
return value
if isinstance(value, str):
return strtobool(value)
raise TypeError(f"Cannot convert value of type '{type(value)}' to Boolean.")
[docs]
def adjust_defaults(cls: type[Settings], fields: list[Attribute]) -> list[Attribute]:
    """Replace default values with the appropriate source, controlled via flags.

    This is an attrs field transformer. Internal attributes are passed through
    unchanged. Every settings attribute gets a default factory that resolves its
    value at instantiation time:

    * If ``_restore_defaults`` is set, the original attribute default is the base
      value. Otherwise, the value of the currently active global settings is used.
    * If ``_restore_environment`` is set, a corresponding ``BAYBE_<NAME>``
      environment variable takes precedence over that base value.

    Args:
        cls: The settings class being defined.
        fields: The attributes collected by attrs for the class.

    Returns:
        The attributes with adjusted defaults, in their original order.
    """

    # We use a factory here because the environment variables should be looked up
    # at instantiation time, not at class definition time
    def make_default_factory(fld: Attribute) -> Any:
        # TODO: https://github.com/python-attrs/attrs/issues/1479
        # (the alias may be unset here, hence the fallback to the attribute name)
        name = fld.alias or fld.name
        env_name = f"BAYBE_{name.upper()}"

        def _(self: Settings) -> Any:
            if self._restore_defaults:
                default = fld.default
            else:
                # Here, the current global settings value is used as default, to
                # enable updating settings one attribute at a time (the fallback to
                # the default happens when the global settings object is itself
                # being created)
                default = getattr(active_settings, fld.name, fld.default)

            if not self._restore_environment:
                return default

            # If enabled, the environment values take precedence for the default
            value = os.getenv(env_name, default)
            if fld.type == "bool":
                return _to_bool(value)
            if fld.type == "int | None" and isinstance(value, str):
                # Environment values are always strings, but such attributes are
                # only validated (not converted), so we parse the integer here
                try:
                    return int(value)
                except ValueError as ex:
                    raise ValueError(
                        f"The value '{value}' of the environment variable "
                        f"'{env_name}' cannot be interpreted as an integer."
                    ) from ex
            return value

        return Factory(_, takes_self=True)

    return [
        fld
        if fld.name in cls._internal_attributes
        else fld.evolve(default=make_default_factory(fld))
        for fld in fields
    ]
@define(frozen=True)
class _RandomState:
    """Snapshot of the random states of all managed numeric libraries.

    Creating an instance without arguments captures the current states.
    """

    state_builtin = field(factory=random.getstate)
    """The state of the built-in random number generator."""

    state_numpy = field(factory=np.random.get_state)
    """The state of the Numpy random number generator."""

    state_torch = field()  # set by default method below (for lazy torch loading)
    """The state of the Torch random number generator."""

    @state_torch.default
    def _default_state_torch(self) -> Any:
        """Fetch the current Torch random state, importing Torch lazily."""
        from torch import get_rng_state

        return get_rng_state()

    def activate(self) -> None:
        """Make the stored states the current states of all generators."""
        from torch import set_rng_state

        set_rng_state(self.state_torch)
        np.random.set_state(self.state_numpy)
        random.setstate(self.state_builtin)

    @classmethod
    def activate_from_seed(cls, seed: int) -> Self:
        """Activate the random state corresponding to a given seed.

        Args:
            seed: The seed applied to all managed generators.

        Returns:
            A snapshot of the random states right after seeding.
        """
        import torch

        for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
            apply_seed(seed)
        return cls()
def _on_set_random_seed(instance: Settings, __: Attribute, value: _TSeed) -> _TSeed:
    """Seed the random number generators when the global seed setting changes.

    Intended as an attrs ``on_setattr`` hook. Seeding only happens when the
    attribute is set on the global settings instance and the new value is not
    ``None``. The value is returned unchanged in all cases.
    """
    targets_global_settings = id(instance) == Settings._global_settings_id
    if targets_global_settings and value is not None:
        _RandomState.activate_from_seed(value)
    return value
def _convert_cache_directory(
value: str | Path | None, field: Attribute, /
) -> Path | None:
"""Attrs converter for the cache directory setting."""
if value is None or value == "":
return None
try:
return Path(value)
except Exception as ex:
raise type(ex)(
f"Cannot set '{field.alias}' to '{value}'. "
f"Expected 'None' or a path-like object."
) from ex
[docs]
@define(kw_only=True, field_transformer=adjust_defaults)
class Settings(_SlottedContextDecorator):
    """BayBE settings.

    A settings object can be activated explicitly via :meth:`activate`, used as a
    context manager, or used as a function decorator. Activation copies its values
    onto the global settings object. Leaving the context (or calling
    :meth:`restore_previous`) brings back the previously active values and the
    random state captured on activation.
    """

    # >>>>> Internal
    _global_settings_id: ClassVar[int]
    """The id of the global settings instance.

    Useful to identify if an action is performed on the global or a local instance."""

    _previous_random_state: _RandomState | None = field(init=False, default=None)
    """The random state captured on activation (used for context management)."""

    _previous_settings: Settings | None = field(default=None, init=False)
    """The previously active settings (used for context management)."""
    # <<<<< Internal

    # >>>>> Control flags
    _restore_defaults: bool = field(default=False, validator=instance_of(bool))
    """Controls if settings shall be restored to their default values."""

    _restore_environment: bool = field(default=False, validator=instance_of(bool))
    """Controls if environment variables shall be used to initialize settings."""
    # <<<<< Control flags

    # >>>>> Settings attributes
    cache_campaign_recommendations: bool = field(
        default=True, validator=instance_of(bool)
    )
    """Controls if campaigns cache their latest recommendation."""

    cache_directory: Path | None = field(
        default=Path(tempfile.gettempdir()) / ".baybe_cache",
        converter=Converter(_convert_cache_directory, takes_field=True),  # type: ignore[misc]
    )
    """The directory used for caching. Set to "" or ``None`` to disable caching."""

    float_precision_numpy: int = field(
        default=64, converter=int, validator=in_((16, 32, 64))
    )
    """The floating point precision used for NumPy arrays."""

    float_precision_torch: int = field(
        default=64, converter=int, validator=in_((16, 32, 64))
    )
    """The floating point precision used for Torch tensors."""

    parallelize_simulations: bool = field(default=True, validator=instance_of(bool))
    """Controls if simulation runs in `xyzpy <https://xyzpy.readthedocs.io/en/latest/index.html>`_ are executed in parallel."""  # noqa: E501

    preprocess_dataframes: bool = field(default=True, validator=instance_of(bool))
    """Controls if dataframe content is validated and normalized before used."""

    random_seed: int | None = field(
        default=None,
        validator=optional_v(instance_of(int)),
        on_setattr=[validate, _on_set_random_seed],
    )
    """The used random seed."""

    _use_fpsample: AutoBool = field(
        alias="use_fpsample",
        default=AutoBool.AUTO,
        converter=AutoBool.from_unstructured,  # type: ignore[misc]
    )
    """Controls if `fpsample <https://github.com/leonardodalinky/fpsample>`_ acceleration is to be used, if available."""  # noqa: E501

    _use_polars_for_constraints: AutoBool = field(
        alias="use_polars_for_constraints",
        default=AutoBool.AUTO,
        converter=AutoBool.from_unstructured,  # type: ignore[misc]
    )
    """Controls if polars acceleration is to be used for constraints, if available."""
    # <<<<< Settings attributes

    def __attrs_pre_init__(self) -> None:
        """Translate deprecated environment variables and reject unknown ones.

        Raises:
            ValueError: If a whitelisted non-settings variable has an invalid value.
            RuntimeError: If an unknown ``BAYBE_*`` environment variable is set.
        """
        # >>>>> Deprecation
        flds = fields(Settings)
        pairs: list[tuple[str, Attribute]] = [
            ("BAYBE_NUMPY_USE_SINGLE_PRECISION", flds.float_precision_numpy),
            ("BAYBE_TORCH_USE_SINGLE_PRECISION", flds.float_precision_torch),
            ("BAYBE_DEACTIVATE_POLARS", flds._use_polars_for_constraints),
            ("BAYBE_PARALLEL_SIMULATION_RUNS", flds.parallelize_simulations),
            ("BAYBE_CACHE_DIR", flds.cache_directory),
        ]
        for env_var, fld in pairs:
            # The deprecated variable is removed from the process environment and
            # re-inserted under its new name (with a translated value, if needed)
            if (value := os.environ.pop(env_var, None)) is not None:
                warnings.warn(
                    f"The environment variable '{env_var}' has "
                    f"been deprecated and support will be dropped in a future version. "
                    f"Please use 'BAYBE_{(fld.alias or fld.name).upper()}' instead. "
                    f"For now, we've automatically handled the translation for you.",
                    DeprecationWarning,
                )
                if env_var.endswith("SINGLE_PRECISION"):
                    value = "32" if _to_bool(value) else "64"
                elif env_var.endswith("POLARS"):
                    # The old flag deactivated polars, the new one activates it
                    value = "false" if _to_bool(value) else "true"
                elif env_var.endswith("SIMULATION_RUNS"):
                    value = "true" if _to_bool(value) else "false"
                os.environ[f"BAYBE_{(fld.alias or fld.name).upper()}"] = value
        # <<<<< Deprecation

        known_env_vars = {
            f"BAYBE_{attr.alias.upper()}" for attr in self._settings_attributes
        }
        _validate_whitelist_env_vars(
            {
                k: v
                for k, v in os.environ.items()
                if k.startswith("BAYBE_") and k not in known_env_vars
            }
        )

    def __enter__(self) -> Settings:
        """Activate the settings when entering the context."""
        self.activate()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Restore the previous settings when leaving the context."""
        self.restore_previous()

    @_use_polars_for_constraints.validator
    def _validate_use_polars_for_constraints(self, _, value: AutoBool) -> None:
        """Reject explicitly enabling polars when it is not installed."""
        if value is AutoBool.TRUE and not POLARS_INSTALLED:
            raise OptionalImportError(
                _MISSING_PACKAGE_ERROR_MESSAGE.format(package_name="polars")
            )

    @_use_fpsample.validator
    def _validate_use_fpsample(self, _, value: AutoBool) -> None:
        """Reject explicitly enabling fpsample when it is not installed."""
        if value is AutoBool.TRUE and not FPSAMPLE_INSTALLED:
            raise OptionalImportError(
                _MISSING_PACKAGE_ERROR_MESSAGE.format(package_name="fpsample")
            )

    @property
    def use_polars_for_constraints(self) -> bool:
        """Indicates if Polars is enabled (i.e., installed and set to be used)."""
        return self._use_polars_for_constraints.evaluate(lambda: POLARS_INSTALLED)

    @use_polars_for_constraints.setter
    def use_polars_for_constraints(self, value: AutoBool | bool, /) -> None:
        # Note: uses attrs converter
        self._use_polars_for_constraints = value  # type: ignore[assignment]

    @property
    def use_fpsample(self) -> bool:
        """Indicates if `fpsample <https://github.com/leonardodalinky/fpsample>`_ is enabled (i.e., installed and set to be used)."""  # noqa: E501
        return self._use_fpsample.evaluate(lambda: FPSAMPLE_INSTALLED)

    @use_fpsample.setter
    def use_fpsample(self, value: AutoBool | bool, /) -> None:
        # Note: uses attrs converter
        self._use_fpsample = value  # type: ignore[assignment]

    @property
    def DTypeFloatNumpy(self) -> type[np.floating]:
        """The floating point data type used for NumPy arrays."""
        return getattr(np, f"float{self.float_precision_numpy}")

    @property
    def DTypeFloatTorch(self) -> torch.dtype:
        """The floating point data type used for Torch tensors."""
        import torch

        return getattr(torch, f"float{self.float_precision_torch}")

    @classproperty
    def _internal_attributes(cls) -> frozenset[str]:
        """The names of the internal attributes not representing settings."""  # noqa: D401
        # IMPROVE: This approach is not type-safe but the set is needed already at
        #   class definition time, which means we cannot use `attrs.fields` or similar.
        #   Perhaps `typing.Annotated` can be used, if there's an elegant way to
        #   resolve the stringified types coming from `__future__.annotations`?
        return frozenset(
            {
                "_previous_random_state",
                "_previous_settings",
                "_restore_defaults",
                "_restore_environment",
            }
        )

    @classproperty
    def _settings_attributes(cls) -> tuple[Attribute, ...]:
        """The attributes representing the available settings."""  # noqa: D401
        return tuple(
            fld
            for fld in fields(Settings)
            if fld.name not in Settings._internal_attributes
        )

    def activate(self) -> Settings:
        """Activate the settings globally.

        The currently active settings are remembered so that they can be brought
        back via :meth:`restore_previous`. If a random seed is involved (either in
        the settings being activated or in the settings being replaced), the
        current random state is remembered as well.

        Returns:
            The settings object itself.
        """
        previous = deepcopy(active_settings)
        self._previous_settings = previous

        # The snapshot must be taken *before* overwriting the global settings,
        # since setting a seed on the global object immediately re-seeds the
        # generators. It is only taken when seeding actually happens (now or on
        # restoration), to avoid loading torch unnecessarily.
        seeding_involved = (
            self.random_seed is not None or previous.random_seed is not None
        )
        self._previous_random_state = _RandomState() if seeding_involved else None

        self.overwrite(active_settings)
        if self.random_seed is not None:
            _RandomState.activate_from_seed(self.random_seed)
        return self

    def restore_previous(self) -> None:
        """Restore the previous settings.

        If a random state was captured on activation, it is reactivated so that
        the random number generators continue from where they were before.

        Raises:
            RuntimeError: If the settings have not been activated before.
        """
        if self._previous_settings is None:
            raise RuntimeError(
                "The settings have not yet been activated, "
                "so there are no previous settings to restore."
            )
        self._previous_settings.overwrite(active_settings)
        self._previous_settings = None

        # Writing back a previous seed re-seeds the generators, so the captured
        # state must be reactivated *after* the settings have been restored
        if self._previous_random_state is not None:
            self._previous_random_state.activate()
            self._previous_random_state = None

    def overwrite(self, target: Settings) -> None:
        """Overwrite the settings of another :class:`Settings` object.

        Args:
            target: The settings object receiving the values of this object.
        """
        for fld in self._settings_attributes:
            setattr(target, fld.name, getattr(self, fld.name))
# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()

# Create the actual global settings object, replacing the `None` placeholder assigned
# near the top of the module. With `restore_environment=True`, values from `BAYBE_*`
# environment variables take precedence over the attribute defaults.
active_settings = Settings(restore_environment=True)
# Remember the identity of the global object so that attribute hooks can distinguish
# it from local `Settings` instances
Settings._global_settings_id = id(active_settings)
"""The currently active global settings instance."""