Source code for baybe.surrogates.gaussian_process.presets.botorch

"""BoTorch preset for Gaussian process surrogates."""

from __future__ import annotations

import gc
from itertools import chain
from typing import TYPE_CHECKING, ClassVar

from attrs import define
from typing_extensions import override

from baybe.kernels.base import Kernel
from baybe.parameters.enum import _ParameterKind
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components import LikelihoodFactoryProtocol
from baybe.surrogates.gaussian_process.components.fit_criterion import (
    FitCriterion,
    PlainFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
    ICMKernelFactory,
    _PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.mean import MeanFactoryProtocol

if TYPE_CHECKING:
    from gpytorch.kernels import Kernel as GPyTorchKernel
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from gpytorch.means import Mean as GPyTorchMean
    from torch import Tensor


@define
class BotorchKernelFactory(_PureKernelFactory):
    """Kernel factory assembling covariance modules from BoTorch building blocks.

    Regular parameters are covered by the kernel returned from BoTorch's
    ``get_covar_module_with_dim_scaled_prior``, restricted to the
    computational-representation columns of those parameters. When the search
    space carries a task column, that kernel is paired with a
    ``PositiveIndexKernel`` on the task column and handed to
    :class:`ICMKernelFactory`, whose result is returned.
    """

    # See base class.
    _uses_parameter_names: ClassVar[bool] = True

    # See base class.
    _supported_parameter_kinds: ClassVar[_ParameterKind] = (
        _ParameterKind.REGULAR | _ParameterKind.TASK
    )

    @override
    def _make(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel | GPyTorchKernel:
        from botorch.models.kernels.positive_index import PositiveIndexKernel
        from botorch.models.utils.gpytorch_modules import (
            get_covar_module_with_dim_scaled_prior,
        )

        def _is_regular(name: str) -> bool:
            """Tell whether the named search space parameter is of regular kind."""
            parameter = searchspace.get_parameters_by_name([name])[0]
            return parameter._kind is _ParameterKind.REGULAR

        names = self.get_parameter_names(searchspace)

        # Only regular parameters enter the base kernel; they are addressed via
        # the column indices they occupy in the computational representation.
        # The task parameter (if any) is handled separately below.
        regular_names = [n for n in names if _is_regular(n)]
        regular_dims = list(
            chain.from_iterable(
                map(searchspace.get_comp_rep_parameter_indices, regular_names)
            )
        )
        covar_regular = get_covar_module_with_dim_scaled_prior(
            ard_num_dims=len(regular_dims), active_dims=regular_dims
        )

        task_column = searchspace.task_idx
        if task_column is not None:
            # Multi-task setting: full-rank index kernel acting on the task column,
            # combined with the regular-parameter kernel by the ICM factory.
            n_tasks = searchspace.n_tasks
            covar_task = PositiveIndexKernel(
                num_tasks=n_tasks,
                rank=n_tasks,
                active_dims=[task_column],
            )
            return ICMKernelFactory(covar_regular, covar_task)(
                searchspace, train_x, train_y
            )

        # Single-task setting: the regular-parameter kernel is used as-is.
        return covar_regular
class BotorchMeanFactory(MeanFactoryProtocol):
    """Factory producing GPyTorch mean modules in the BoTorch preset.

    A plain ``ConstantMean`` is returned for single-task search spaces; otherwise
    a ``ConstantMean`` is wrapped in ``HadamardConstantMean`` together with the
    number of tasks and the index of the task column.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> GPyTorchMean:
        from gpytorch.means import ConstantMean

        from baybe.surrogates.gaussian_process.components._gpytorch import (
            HadamardConstantMean,
        )

        n_tasks = searchspace.n_tasks
        if n_tasks != 1:
            task_column = searchspace.task_idx
            # A multi-task search space is expected to expose its task column;
            # the assertion also narrows the type for static checkers.
            assert task_column is not None
            return HadamardConstantMean(ConstantMean(), n_tasks, task_column)

        return ConstantMean()
class BotorchLikelihoodFactory(LikelihoodFactoryProtocol):
    """Factory producing GPyTorch likelihoods in the BoTorch preset.

    Single-task search spaces get BoTorch's Gaussian likelihood with a log-normal
    prior (``get_gaussian_likelihood_with_lognormal_prior``). Multi-task search
    spaces get the likelihood built by ``make_botorch_multitask_likelihood`` for
    the number of tasks and the task column index.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> GPyTorchLikelihood:
        n_tasks = searchspace.n_tasks
        if n_tasks != 1:
            from baybe.surrogates.gaussian_process.components._gpytorch import (
                make_botorch_multitask_likelihood,
            )

            task_column = searchspace.task_idx
            # A multi-task search space is expected to expose its task column;
            # the assertion also narrows the type for static checkers.
            assert task_column is not None
            return make_botorch_multitask_likelihood(
                num_tasks=n_tasks, task_feature=task_column
            )

        from botorch.models.utils.gpytorch_modules import (
            get_gaussian_likelihood_with_lognormal_prior,
        )

        return get_gaussian_likelihood_with_lognormal_prior()
# Collect leftover original slotted classes processed by `attrs.define` gc.collect() # Aliases for generic preset imports KERNEL_FACTORY = BotorchKernelFactory() MEAN_FACTORY = BotorchMeanFactory() LIKELIHOOD_FACTORY = BotorchLikelihoodFactory() FIT_CRITERION_FACTORY = PlainFitCriterionFactory(FitCriterion.MARGINAL_LOG_LIKELIHOOD)