Source code for baybe.surrogates.gaussian_process.presets.chen
"""Preset for adaptive kernel hyperpriors proposed by :cite:p:`Chen2026`."""
from __future__ import annotations
import gc
import math
from typing import TYPE_CHECKING, ClassVar
from attrs import define, field
from typing_extensions import override
from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.priors.basic import GammaPrior
from baybe.surrogates.gaussian_process.components.fit_criterion import (
_MLLForNonTLFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
LazyGaussianLikelihoodFactory,
)
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory
if TYPE_CHECKING:
from torch import Tensor
from baybe.kernels.base import Kernel
from baybe.searchspace.core import SearchSpace
[docs]
@define
class CHENKernelFactory(_PureKernelFactory):
"""A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501
_uses_parameter_names: ClassVar[bool] = True
# See base class.
parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)
@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
lengthscale = 0.4 * math.sqrt(train_x.shape[-1]) + 4.0
lengthscale_prior = GammaPrior(2.0 * lengthscale, 2.0)
lengthscale_initial_value = lengthscale
outputscale_prior = GammaPrior(1.0 * lengthscale, 1.0)
outputscale_initial_value = lengthscale
return ScaleKernel(
MaternKernel(
nu=2.5,
lengthscale_prior=lengthscale_prior,
lengthscale_initial_value=lengthscale_initial_value,
parameter_names=self.get_parameter_names(searchspace),
),
outputscale_prior=outputscale_prior,
outputscale_initial_value=outputscale_initial_value,
)
CHENFitCriterionFactory = _MLLForNonTLFitCriterionFactory()
"""A factory providing fitting criteria for the CHEN preset."""
# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
# Preset defaults
KERNEL_FACTORY = CHENKernelFactory()
MEAN_FACTORY = LazyConstantMeanFactory()
LIKELIHOOD_FACTORY = LazyGaussianLikelihoodFactory()
FIT_CRITERION_FACTORY = CHENFitCriterionFactory