Source code for baybe.surrogates.gaussian_process.components.generic
"""Component factories for the Gaussian process surrogate."""
from __future__ import annotations
import gc
import sys
from enum import Enum
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeAlias, TypeVar
from attrs import Attribute, define, field
from typing_extensions import override
from baybe.kernels.base import Kernel
from baybe.searchspace import SearchSpace
from baybe.serialization.core import block_serialization_hook, converter
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion
# Components that BayBE implements natively (as opposed to GPyTorch objects)
BayBEGPComponent: TypeAlias = Kernel | FitCriterion

if TYPE_CHECKING:
    # The GPyTorch/torch names below are needed for annotations only, so they are
    # imported exclusively for static type checking and not at runtime
    from gpytorch.kernels import Kernel as GPyTorchKernel
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from gpytorch.means import Mean as GPyTorchMean
    from torch import Tensor

    # For the type checker, a GP component is either a BayBE or a GPyTorch object
    GPyTorchGPComponent: TypeAlias = GPyTorchKernel | GPyTorchMean | GPyTorchLikelihood
    GPComponent: TypeAlias = BayBEGPComponent | GPyTorchGPComponent
else:
    # At runtime, we use only the BayBE types for serialization compatibility
    GPComponent: TypeAlias = BayBEGPComponent

# Covariant type variable for the kind of component a factory produces
_T_co = TypeVar("_T_co", bound=GPComponent, covariant=True)
[docs]
class GPComponentType(Enum):
    """Enumeration of the kinds of components a Gaussian process is assembled from."""

    KERNEL = "KERNEL"
    """Gaussian process kernel."""

    MEAN = "MEAN"
    """Gaussian process mean function."""

    LIKELIHOOD = "LIKELIHOOD"
    """Gaussian process likelihood."""

    CRITERION = "CRITERION"
    """Gaussian process fitting criterion."""

    def get_types(self) -> tuple[type, ...]:
        """Return the BayBE and GPyTorch classes accepted for this component type.

        GPyTorch classes are only reported when the ``gpytorch`` package is already
        present in ``sys.modules``; this method does not load GPyTorch on its own.

        Returns:
            The accepted classes, with the BayBE class (if one exists for this
            component type) listed before the GPyTorch class. The tuple can be
            empty, e.g. for a mean or likelihood while GPyTorch is not loaded.
        """
        accepted: list[type[GPComponent]] = []
        kinds = GPComponentType

        # BayBE-native class: only kernels and fit criteria have one
        if self is kinds.CRITERION:
            from baybe.surrogates.gaussian_process.components.fit_criterion import (
                FitCriterion,
            )

            accepted.append(FitCriterion)
        elif self is kinds.KERNEL:
            from baybe.kernels.base import Kernel

            accepted.append(Kernel)

        # Without a loaded GPyTorch there is nothing more to add
        if sys.modules.get("gpytorch") is None:
            return tuple(accepted)

        # GPyTorch counterpart: kernels, means and likelihoods have one
        if self is kinds.LIKELIHOOD:
            from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood

            accepted.append(GPyTorchLikelihood)
        elif self is kinds.MEAN:
            from gpytorch.means import Mean as GPyTorchMean

            accepted.append(GPyTorchMean)
        elif self is kinds.KERNEL:
            from gpytorch.kernels import Kernel as GPyTorchKernel

            accepted.append(GPyTorchKernel)

        return tuple(accepted)
def _is_gpytorch_component_class(obj: Any, /) -> bool:
    """Check if a class is a GPyTorch component class using lazy loading.

    The check never loads GPyTorch by itself: if the ``gpytorch`` package is not
    already present in ``sys.modules``, no object can be one of its classes and
    the answer is ``False``.

    Args:
        obj: The object to check. Typically a class, but any object is tolerated
            because the function is also registered as a converter predicate,
            which may be handed arbitrary type forms.

    Returns:
        ``True`` if ``obj`` is a subclass of a GPyTorch kernel, mean or
        likelihood class, ``False`` otherwise (including for non-class inputs).
    """
    if sys.modules.get("gpytorch") is None:
        return False

    from gpytorch.kernels import Kernel as GPyTorchKernel
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from gpytorch.means import Mean as GPyTorchMean

    try:
        return issubclass(obj, (GPyTorchKernel, GPyTorchMean, GPyTorchLikelihood))
    except TypeError:
        # `issubclass` rejects anything that is not a class (instances,
        # parametrized generics such as `list[int]`, ...). Such objects are simply
        # not GPyTorch component classes, so answer instead of raising.
        return False
def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None:
    """Ensure that the given value is a BayBE or a GPyTorch GP component.

    The signature follows the ``(instance, attribute, value)`` convention of
    attrs validators.

    Args:
        instance: The object whose attribute is being validated.
        attribute: The attribute being validated.
        value: The value assigned to the attribute.

    Raises:
        TypeError: If the value is neither a BayBE nor a GPyTorch GP component.
    """
    is_baybe_component = isinstance(value, BayBEGPComponent)

    # The GPyTorch check is only consulted when the BayBE check did not succeed
    if not is_baybe_component and not _is_gpytorch_component_class(type(value)):
        raise TypeError(
            f"The object provided for '{attribute.alias}' of "
            f"'{instance.__class__.__name__}' must be a BayBE or a GPyTorch GP component. "
            f"Got: {type(value)}"
        )
[docs]
class GPComponentFactoryProtocol(Protocol, Generic[_T_co]):
    """The call interface that every GP component factory is expected to offer."""

    def __call__(
        self,
        searchspace: SearchSpace,
        train_x: Tensor,
        train_y: Tensor,
    ) -> _T_co:
        """Create a GP component for the given recommendation context.

        Args:
            searchspace: The search space of the recommendation context.
            train_x: The training inputs.
            train_y: The training targets.

        Returns:
            The created component.
        """
        ...
[docs]
@define(frozen=True)
class PlainGPComponentFactory(GPComponentFactoryProtocol[_T_co], SerialMixin):
    """A trivial factory that hands out one fixed, pre-defined component.

    Every call returns the very same component object, regardless of the
    context passed in.
    """

    component: _T_co = field(validator=_validate_component)
    """The fixed component to be returned by the factory."""

    @override
    def __call__(
        self,
        searchspace: SearchSpace,
        train_x: Tensor,
        train_y: Tensor,
    ) -> _T_co:
        # The context arguments belong to the factory interface but play no role
        # here, since the component is fixed
        return self.component
[docs]
def to_component_factory(
    obj: GPComponent | GPComponentFactoryProtocol,
    /,
    *,
    component_type: GPComponentType | None = None,
) -> GPComponentFactoryProtocol:
    """Turn a component into a plain component factory, passing factories through.

    Args:
        obj: The component or factory to convert.
        component_type: An optional restriction on the allowed component type.
            It is only applied to components, not to objects passed through.

    Returns:
        A plain factory wrapping the component if a component was given.
        Otherwise, the input object itself.

    Raises:
        TypeError: If the given component does not match the allowed types.
    """
    is_component = (
        isinstance(obj, BayBEGPComponent)
        or _is_gpytorch_component_class(type(obj))
    )

    # Anything that is not a component is treated as a factory and returned as is
    if not is_component:
        return obj

    # Enforce the requested component type, if any, before wrapping
    if component_type is not None and not isinstance(
        obj, allowed_types := component_type.get_types()
    ):
        raise TypeError(
            f"Component must be one of {allowed_types}. Got: {type(obj)}"
        )

    return PlainGPComponentFactory(obj)
# Block serialization of GPyTorch component classes (kernels, means and
# likelihoods, as matched by the predicate) since not yet supported. The factory
# ignores the matched type and always hands out the same blocking hook.
converter.register_unstructure_hook_factory(
    _is_gpytorch_component_class,
    lambda _: block_serialization_hook,
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()