# Source code for baybe.surrogates.gaussian_process.presets.edbo

"""EDBO preset :cite:p:`Shields2021`."""

from __future__ import annotations

import gc
from collections.abc import Collection
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters import TaskParameter
from baybe.parameters.enum import SubstanceEncoding
from baybe.parameters.selectors import (
    ParameterSelectorProtocol,
    TypeSelector,
    to_parameter_selector,
)
from baybe.parameters.substance import SubstanceParameter
from baybe.priors.basic import GammaPrior
from baybe.searchspace.discrete import SubspaceDiscrete
from baybe.surrogates.gaussian_process.components.fit_criterion import (
    _MLLForNonTLFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
    _PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
    LikelihoodFactoryProtocol,
)
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory

if TYPE_CHECKING:
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from torch import Tensor

    from baybe.kernels.base import Kernel
    from baybe.searchspace.core import SearchSpace


def _contains_encoding(
    subspace: SubspaceDiscrete, encodings: Collection[SubstanceEncoding]
) -> bool:
    """Check whether a substance parameter with one of the given encodings exists.

    Args:
        subspace: The discrete subspace whose parameters are inspected.
        encodings: The encodings to look for.

    Returns:
        ``True`` if at least one :class:`SubstanceParameter` in ``subspace`` has
        an encoding contained in ``encodings``, ``False`` otherwise (including
        when the subspace has no substance parameters at all).
    """
    for parameter in subspace.parameters:
        # Only substance parameters are considered; other types are skipped.
        if not isinstance(parameter, SubstanceParameter):
            continue
        if parameter.encoding in encodings:
            return True
    return False


# Substance encodings that the EDBO kernel and likelihood factories in this module
# look for (via `_contains_encoding`) when deciding, together with a dimensionality
# threshold, whether to switch to the descriptor-specific prior settings.
_EDBO_ENCODINGS = (
    SubstanceEncoding.MORDRED,
    SubstanceEncoding.RDKIT,
    SubstanceEncoding.RDKIT2DDESCRIPTORS,
)
"""Encodings relevant to EDBO logic."""


@define
class EDBOKernelFactory(_PureKernelFactory):
    """A factory providing EDBO kernels, as proposed by :cite:p:`Shields2021`.

    The produced kernel is a scaled Matern kernel (``nu=2.5``) whose lengthscale
    and outputscale priors / initial values are chosen from one of four regimes,
    depending on the effective input dimensionality and on whether the search
    space contains substance parameters with one of the EDBO-relevant encodings.

    GitHub repository: https://github.com/b-shields/edbo
    Prior settings: https://github.com/b-shields/edbo/blob/9b41eac3f6d9e520547702fd5b0c7ef6441625a4/edbo/bro.py#L658
    """

    _uses_parameter_names: ClassVar[bool] = True
    # See base class.

    # By default, the kernel acts on all parameters except task parameters.
    parameter_selector: ParameterSelectorProtocol | None = field(
        factory=lambda: TypeSelector([TaskParameter], exclude=True),
        converter=to_parameter_selector,
    )
    # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

    @override
    def _make(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        """Create the EDBO kernel for the given search space and training data.

        Args:
            searchspace: The search space the kernel is built for.
            train_x: The training inputs. Only the size of the last dimension
                is used.
            train_y: The training targets. Not used by this factory.

        Returns:
            A :class:`ScaleKernel` wrapping a :class:`MaternKernel` with
            ``nu=2.5`` and regime-dependent priors and initial values.
        """
        # Task parameters must not inflate the dimensionality that selects the
        # prior regime: they are excluded from the kernel by the default
        # parameter selector, and `EDBOLikelihoodFactory` discounts them in the
        # same way. Without this, kernel and likelihood can end up in different
        # regimes for the same data.
        # NOTE(review): assumes each task parameter occupies exactly one column
        # of `train_x` (same assumption as in `EDBOLikelihoodFactory`) -- confirm.
        n_task_parameters = sum(
            isinstance(p, TaskParameter) for p in searchspace.parameters
        )
        effective_dims = train_x.shape[-1] - n_task_parameters

        switching_condition = _contains_encoding(
            searchspace.discrete, _EDBO_ENCODINGS
        ) and (effective_dims >= 50)

        # low D priors
        if effective_dims < 5:
            lengthscale_prior = GammaPrior(1.2, 1.1)
            lengthscale_initial_value = 0.2
            outputscale_prior = GammaPrior(5.0, 0.5)
            outputscale_initial_value = 8.0

        # DFT optimized priors
        elif switching_condition and effective_dims < 100:
            lengthscale_prior = GammaPrior(2.0, 0.2)
            lengthscale_initial_value = 5.0
            outputscale_prior = GammaPrior(5.0, 0.5)
            outputscale_initial_value = 8.0

        # Mordred optimized priors
        elif switching_condition:
            lengthscale_prior = GammaPrior(2.0, 0.1)
            lengthscale_initial_value = 10.0
            outputscale_prior = GammaPrior(2.0, 0.1)
            outputscale_initial_value = 10.0

        # OHE optimized priors
        else:
            lengthscale_prior = GammaPrior(3.0, 1.0)
            lengthscale_initial_value = 2.0
            outputscale_prior = GammaPrior(5.0, 0.2)
            outputscale_initial_value = 20.0

        return ScaleKernel(
            MaternKernel(
                nu=2.5,
                lengthscale_prior=lengthscale_prior,
                lengthscale_initial_value=lengthscale_initial_value,
                parameter_names=self.get_parameter_names(searchspace),
            ),
            outputscale_prior=outputscale_prior,
            outputscale_initial_value=outputscale_initial_value,
        )
# Plain alias: the EDBO preset reuses `LazyConstantMeanFactory` unchanged. The name
# refers to the factory itself, which is called further below to build the default.
EDBOMeanFactory = LazyConstantMeanFactory
"""A factory providing mean functions for the EDBO preset."""
@define
class EDBOLikelihoodFactory(LikelihoodFactoryProtocol):
    """A factory providing EDBO likelihoods, as proposed by :cite:p:`Shields2021`.

    The produced likelihood is a Gaussian likelihood whose noise prior and initial
    noise value depend on the effective input dimensionality, i.e. the number of
    input columns minus the number of task parameters in the search space.

    GitHub repository: https://github.com/b-shields/edbo
    Prior settings: https://github.com/b-shields/edbo/blob/9b41eac3f6d9e520547702fd5b0c7ef6441625a4/edbo/bro.py#L658
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> GPyTorchLikelihood:
        """Create the EDBO likelihood for the given search space and training data.

        Args:
            searchspace: The search space the likelihood is built for.
            train_x: The training inputs. Only the size of the last dimension
                is used.
            train_y: The training targets. Not used by this factory.

        Returns:
            A GPyTorch Gaussian likelihood with the selected noise prior and
            its noise set to the selected initial value.
        """
        import torch
        from gpytorch.likelihoods import GaussianLikelihood

        # Task parameters do not count towards the effective dimensionality.
        total_dims = train_x.shape[-1]
        n_task_parameters = sum(
            1 for p in searchspace.parameters if isinstance(p, TaskParameter)
        )
        effective_dims = total_dims - n_task_parameters

        has_edbo_encoding = _contains_encoding(searchspace.discrete, _EDBO_ENCODINGS)
        switching_condition = has_edbo_encoding and (effective_dims >= 50)

        # The non-low-D regimes currently share the same settings but are kept
        # as separate branches to mirror the structure of the reference code.
        if effective_dims < 5:
            # low D priors
            prior, initial_value = GammaPrior(1.05, 0.5), 0.1
        elif switching_condition and effective_dims < 100:
            # DFT optimized priors
            prior, initial_value = GammaPrior(1.5, 0.1), 5.0
        elif switching_condition:
            # Mordred optimized priors
            prior, initial_value = GammaPrior(1.5, 0.1), 5.0
        else:
            # OHE optimized priors
            prior, initial_value = GammaPrior(1.5, 0.1), 5.0

        likelihood = GaussianLikelihood(prior.to_gpytorch())
        likelihood.noise = torch.tensor([initial_value])
        return likelihood
# NOTE(review): unlike `EDBOMeanFactory`, this name is bound to the *result of
# calling* `_MLLForNonTLFitCriterionFactory`, not to the callable itself, which is
# why `FIT_CRITERION_FACTORY` below reuses it without a call -- confirm intended.
EDBOFitCriterionFactory = _MLLForNonTLFitCriterionFactory()
"""A factory providing fitting criteria for the EDBO preset."""

# Collect leftover original slotted classes processed by `attrs.define`
# (runs once, as an import-time side effect of this module)
gc.collect()

# Preset defaults: one ready-to-use object per GP component of the EDBO preset
KERNEL_FACTORY = EDBOKernelFactory()
MEAN_FACTORY = EDBOMeanFactory()
LIKELIHOOD_FACTORY = EDBOLikelihoodFactory()
FIT_CRITERION_FACTORY = EDBOFitCriterionFactory