# Source code for baybe.surrogates.gaussian_process.components.fit_criterion

"""Fitting criteria for the Gaussian process surrogate."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

from attrs import define
from typing_extensions import override

if TYPE_CHECKING:
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from gpytorch.mlls import MarginalLogLikelihood
    from gpytorch.models import GP as GPyTorchModel
    from torch import Tensor

    from baybe.searchspace.core import SearchSpace


class FitCriterion(Enum):
    """Fitting criteria supported for optimizing GP hyperparameters."""

    MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD"
    """Exact marginal log-likelihood."""

    LEAVE_ONE_OUT_PSEUDOLIKELIHOOD = "LEAVE_ONE_OUT_PSEUDOLIKELIHOOD"
    """Leave-one-out cross-validation pseudo-likelihood."""

    def to_gpytorch(
        self, likelihood: GPyTorchLikelihood, model: GPyTorchModel
    ) -> MarginalLogLikelihood:
        """Build the GPyTorch MLL object that implements this criterion.

        Args:
            likelihood: The GPyTorch likelihood handed to the MLL constructor.
            model: The GPyTorch model handed to the MLL constructor.

        Returns:
            An instance of the GPyTorch MLL class associated with this member,
            constructed as ``MLLClass(likelihood, model)``.

        Raises:
            KeyError: If the member has no associated MLL class (not reachable
                for the members defined above).
        """
        # At module level, gpytorch is only imported under TYPE_CHECKING. The
        # runtime import is therefore deferred until an MLL is actually requested.
        from gpytorch import mlls

        if self is FitCriterion.MARGINAL_LOG_LIKELIHOOD:
            mll_cls = mlls.ExactMarginalLogLikelihood
        elif self is FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD:
            mll_cls = mlls.LeaveOneOutPseudoLikelihood
        else:
            # Same outcome as a failed lookup in a member -> class mapping.
            raise KeyError(self)
        return mll_cls(likelihood, model)
# Delayed import (placed after FitCriterion is defined) to avoid a circular
# dependency.
from baybe.surrogates.gaussian_process.components.generic import (  # noqa: E402
    GPComponentFactoryProtocol,
    PlainGPComponentFactory,
)

FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion]
"""A protocol defining the interface for fit criterion factories."""

PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
"""A trivial factory that returns a fixed fit criterion."""


@define
class _MLLForNonTLFitCriterionFactory(FitCriterionFactoryProtocol):
    """A fit criterion factory choosing exact MLL outside transfer learning.

    If the search space has no task index (``searchspace.task_idx is None``), the
    exact marginal log-likelihood is returned. Otherwise (transfer learning), the
    choice is delegated to
    :class:`baybe.surrogates.gaussian_process.presets.baybe.BayBEFitCriterionFactory`.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> FitCriterion:
        if searchspace.task_idx is not None:
            # Transfer learning context: defer to the BayBE default factory.
            # Imported locally; presumably to avoid an import cycle with the
            # presets module (NOTE(review): confirm).
            from baybe.surrogates.gaussian_process.presets.baybe import (
                BayBEFitCriterionFactory,
            )

            delegate = BayBEFitCriterionFactory()
            return delegate(searchspace, train_x, train_y)

        return FitCriterion.MARGINAL_LOG_LIKELIHOOD