Source code for baybe.surrogates.gaussian_process.presets.baybe

"""Default preset for Gaussian process surrogates."""

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from typing_extensions import override

from baybe.kernels.base import Kernel
from baybe.kernels.basic import IndexKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.enum import _ParameterKind
from baybe.parameters.selectors import (
    ParameterSelectorProtocol,
    TypeSelector,
    to_parameter_selector,
)
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.fit_criterion import (
    FitCriterion,
    FitCriterionFactoryProtocol,
)
from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
    SmoothedEDBOKernelFactory,
    SmoothedEDBOLikelihoodFactory,
)

if TYPE_CHECKING:
    from torch import Tensor


@define
class BayBEKernelFactory(_PureKernelFactory):
    """Default kernel factory used by Gaussian process surrogates.

    The factory does not build a kernel itself. Depending on whether the search
    space reports a task index, it delegates to the ICM kernel factory
    (task index present) or to the default numerical kernel factory (no task index).
    """

    _supported_parameter_kinds: ClassVar[_ParameterKind] = (
        _ParameterKind.REGULAR | _ParameterKind.TASK
    )
    # See base class.

    @override
    def _make(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        # Function-scope import, kept from the original implementation
        # (presumably to avoid a circular import -- TODO confirm).
        from baybe.surrogates.gaussian_process.components.kernel import (
            ICMKernelFactory,
        )

        # A task index on the search space signals the multi-task setting.
        if searchspace.task_idx is not None:
            delegate = ICMKernelFactory()
        else:
            delegate = BayBENumericalKernelFactory()

        # The chosen factory receives exactly the arguments passed to this method.
        return delegate(searchspace, train_x, train_y)
# Alias (not a subclass): the default numerical kernel factory is the very same
# object as `SmoothedEDBOKernelFactory`.
BayBENumericalKernelFactory = SmoothedEDBOKernelFactory
"""The factory providing the default numerical kernel for Gaussian process surrogates."""  # noqa: E501
def _select_task_parameters() -> TypeSelector:
    """Build the default selector, which matches parameters of type ``TaskParameter``."""
    return TypeSelector([TaskParameter])


@define
class BayBETaskKernelFactory(_PureKernelFactory):
    """Default task-kernel factory for Gaussian process surrogates.

    Produces an ``IndexKernel`` whose number of tasks and rank are both taken from
    the search space's ``n_tasks``.
    """

    _uses_parameter_names: ClassVar[bool] = True
    # See base class.

    _supported_parameter_kinds: ClassVar[_ParameterKind] = _ParameterKind.TASK
    # See base class.

    parameter_selector: ParameterSelectorProtocol | None = field(
        factory=_select_task_parameters,
        converter=to_parameter_selector,
    )
    # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

    @override
    def _make(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        # The same task count is used for both `num_tasks` and `rank`.
        n_tasks = searchspace.n_tasks
        selected_names = self.get_parameter_names(searchspace)
        return IndexKernel(
            num_tasks=n_tasks,
            rank=n_tasks,
            parameter_names=selected_names,
        )
# Alias: the default mean factory is the same object as `LazyConstantMeanFactory`.
BayBEMeanFactory = LazyConstantMeanFactory
"""The factory providing the default mean function for Gaussian process surrogates."""

# Alias: the default likelihood factory is the same object as
# `SmoothedEDBOLikelihoodFactory`.
BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory
"""The factory providing the default likelihood for Gaussian process surrogates."""
@define
class BayBEFitCriterionFactory(FitCriterionFactoryProtocol):
    """Default fitting-criterion factory for Gaussian process surrogates.

    Returns the marginal log-likelihood criterion when the search space has no
    task index, and the leave-one-out pseudo-likelihood criterion otherwise.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> FitCriterion:
        # Only the search space is inspected; the training data is not used here.
        if searchspace.task_idx is None:
            return FitCriterion.MARGINAL_LOG_LIKELIHOOD
        return FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD
# Preset defaults KERNEL_FACTORY = BayBEKernelFactory() MEAN_FACTORY = BayBEMeanFactory() LIKELIHOOD_FACTORY = BayBELikelihoodFactory() FIT_CRITERION_FACTORY = BayBEFitCriterionFactory()