Source code for baybe.kernels.base

"""Base classes for all kernels."""

from __future__ import annotations

import gc
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING, Any

from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import deep_iterable, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import override

from baybe.exceptions import UnmatchedAttributeError
from baybe.priors.base import Prior
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.settings import active_settings
from baybe.utils.basic import classproperty, get_baseclasses, match_attributes, to_tuple

if TYPE_CHECKING:
    from baybe.surrogates.gaussian_process.components.kernel import PlainKernelFactory


@define(frozen=True)
class Kernel(ABC, SerialMixin):
    """Abstract base class for all kernels.

    Besides defining the conversion to gpytorch, the class overloads ``+`` and ``*``
    so that kernels can be composed into additive, product and scale kernels.
    """

    def __add__(self, other: Any) -> Kernel:
        """Create a sum kernel from two kernels.

        Flattens nested sums so that ``(a + b) + c`` yields
        ``AdditiveKernel([a, b, c])`` instead of
        ``AdditiveKernel([AdditiveKernel([a, b]), c])``.

        Args:
            other: The right-hand operand.

        Returns:
            The additive kernel, or ``NotImplemented`` if ``other`` is not a kernel.
        """
        if isinstance(other, Kernel):
            # Imported locally, as done throughout this class for composite kernels
            from baybe.kernels.composite import AdditiveKernel

            left = self.base_kernels if isinstance(self, AdditiveKernel) else (self,)
            right = (
                other.base_kernels if isinstance(other, AdditiveKernel) else (other,)
            )
            return AdditiveKernel([*left, *right])
        return NotImplemented

    def __radd__(self, other: Any) -> Kernel:
        """Support right-hand addition for kernel objects (``other + self``).

        Args:
            other: The left-hand operand of the addition.

        Returns:
            The kernel itself if ``other`` equals zero, the additive kernel if
            ``other`` is a kernel, and ``NotImplemented`` otherwise.
        """
        # Enable use with built-in sum(), which starts with 0 + first_element.
        if other == 0:
            return self
        if isinstance(other, Kernel):
            # This is the reflected operation: `other` is the LEFT operand, hence
            # its kernels must come first in the resulting sum.
            return other.__add__(self)
        return NotImplemented

    def __mul__(self, other: Any) -> Kernel:
        """Create a product kernel or scale kernel.

        When multiplied with another kernel, a product kernel is created. Nested
        products are flattened so that ``(a * b) * c`` yields
        ``ProductKernel([a, b, c])``.

        When multiplied with a numeric constant, a scale kernel with a fixed
        (non-trainable) output scale is created.

        Args:
            other: The right-hand operand (a kernel or an ``int``/``float``).

        Returns:
            The product kernel, the scale kernel, the kernel itself (when
            multiplied with one), or ``NotImplemented`` for unsupported operands.
        """
        if isinstance(other, Kernel):
            from baybe.kernels.composite import ProductKernel

            left = self.base_kernels if isinstance(self, ProductKernel) else (self,)
            right = other.base_kernels if isinstance(other, ProductKernel) else (other,)
            return ProductKernel([*left, *right])
        if isinstance(other, (int, float)):
            # Multiplying with one is the identity, so no wrapper kernel is needed
            if other == 1:
                return self
            from baybe.kernels.composite import ScaleKernel

            return ScaleKernel(
                base_kernel=self,
                outputscale_initial_value=float(other),
                outputscale_trainable=False,
            )
        return NotImplemented

    def __rmul__(self, other: Any) -> Kernel:
        """Support right-hand multiplication, enabling ``constant * kernel``.

        Args:
            other: The left-hand operand of the multiplication.

        Returns:
            The same result as ``other * self``; see :meth:`__mul__`.
        """
        if isinstance(other, Kernel):
            # This is the reflected operation: `other` is the LEFT operand, hence
            # its kernels must come first in the resulting product.
            return other.__mul__(self)
        # Scaling with a constant is independent of the operand order.
        # Enable use with math.prod(), which starts with 1 * first_element.
        return self.__mul__(other)

    @classproperty
    def _whitelisted_attributes(cls) -> frozenset[str]:
        """Attribute names to exclude from gpytorch matching."""
        return frozenset()

    def to_factory(self) -> PlainKernelFactory:
        """Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`."""  # noqa: E501
        # Imported locally; at module level the name is only available for typing
        from baybe.surrogates.gaussian_process.components.kernel import (
            PlainKernelFactory,
        )

        return PlainKernelFactory(self)

    @abstractmethod
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        """Get the active dimensions and the number of ARD dimensions.

        Args:
            searchspace: The search space the kernel is applied to.

        Returns:
            A tuple ``(active_dims, ard_num_dims)``, where either entry can be
            ``None``. Both are forwarded to the gpytorch kernel constructor.
        """

    def to_gpytorch(self, searchspace: SearchSpace):
        """Create the gpytorch representation of the kernel.

        Args:
            searchspace: The search space used to determine the kernel dimensions.
                It is also passed on to the conversion of inner kernels.

        Returns:
            The gpytorch kernel object.

        Raises:
            UnmatchedAttributeError: If the kernel has attributes that match no
                constructor parameter of the corresponding gpytorch kernel class
                or its base classes (whitelisted attributes, initial values and
                trainability flags excluded).
        """
        import gpytorch.kernels

        active_dims, ard_num_dims = self._get_dimensions(searchspace)

        # Get corresponding gpytorch kernel class and its base classes
        try:
            kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
        except AttributeError:
            # Kernels not provided by gpytorch are looked up in botorch instead
            import botorch.models.kernels.positive_index

            kernel_cls = getattr(
                botorch.models.kernels.positive_index, self.__class__.__name__
            )
        base_classes = get_baseclasses(kernel_cls, abstract=True)

        # Fetch the necessary gpytorch constructor parameters of the kernel.
        # NOTE: In gpytorch, some attributes (like the kernel lengthscale) are handled
        #   via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to
        #   just check the fields of the actual class, but also those of its base
        #   classes.
        kernel_attrs: dict[str, Any] = {}
        unmatched_attrs: dict[str, Any] = {}
        for cls in [kernel_cls, *base_classes]:
            matched, unmatched = match_attributes(self, cls.__init__, strict=False)
            kernel_attrs.update(matched)
            unmatched_attrs.update(unmatched)

        # Sanity check: all attributes of the BayBE kernel need a corresponding match
        # with the gpytorch kernel signature (otherwise, the BayBE kernel class is
        # misconfigured). Exceptions: initial values and trainability flags are not used
        # during construction but are set on the created object after construction.
        # NOTE: An attribute unmatched for one class may still be matched by another
        #   class of the hierarchy, hence the matched names are subtracted.
        missing = (
            set(unmatched_attrs) - set(kernel_attrs) - self._whitelisted_attributes
        )
        if leftover := {
            m for m in missing if not m.endswith(("_initial_value", "_trainable"))
        }:
            raise UnmatchedAttributeError(leftover)

        # Convert specified priors to gpytorch, if provided
        prior_dict = {
            key: value.to_gpytorch()
            for key, value in kernel_attrs.items()
            if isinstance(value, Prior)
        }

        # Convert specified inner kernels to gpytorch, if provided
        kernel_dict = {
            key: value.to_gpytorch(searchspace)
            for key, value in kernel_attrs.items()
            if isinstance(value, Kernel)
        }

        # Create the kernel with all its inner gpytorch objects
        kernel_attrs.update(kernel_dict)
        kernel_attrs.update(prior_dict)
        gpytorch_kernel = kernel_cls(
            **kernel_attrs, ard_num_dims=ard_num_dims, active_dims=active_dims
        )

        # If the kernel has a lengthscale, set its initial value
        if kernel_cls.has_lengthscale:
            import torch

            # We can ignore mypy here and simply assume that the corresponding BayBE
            # kernel class has the necessary lengthscale attribute defined. This is
            # safer than using a `hasattr` check in the above if-condition since for
            # the latter the code would silently fail when forgetting to add the
            # attribute to a new kernel class / misspelling it.
            if (initial_value := self.lengthscale_initial_value) is not None:  # type: ignore[attr-defined]
                gpytorch_kernel.lengthscale = torch.tensor(
                    initial_value, dtype=active_settings.DTypeFloatTorch
                )

        return gpytorch_kernel
@define(frozen=True)
class BasicKernel(Kernel, ABC):
    """Abstract base class for all basic kernels."""

    parameter_names: tuple[str, ...] | None = field(
        kw_only=True,
        default=None,
        converter=optional_c(to_tuple),
        validator=optional_v(
            deep_iterable(
                member_validator=instance_of(str),
                iterable_validator=instance_of(tuple),
            )
        ),
    )
    """Optional names specifying the parameters the kernel should act on."""

    @override
    @classproperty
    def _whitelisted_attributes(cls) -> frozenset[str]:
        """Attribute names to exclude from gpytorch matching."""
        return frozenset(("parameter_names",))

    @override
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        """Get the active dimensions and the number of ARD dimensions.

        Args:
            searchspace: The search space providing the computational
                representation columns and the parameter-to-index lookup.

        Returns:
            A tuple ``(active_dims, ard_num_dims)``. Without parameter names,
            ``active_dims`` is ``None`` and ``ard_num_dims`` is the number of
            computational representation columns. Otherwise, ``active_dims``
            holds the indices of the named parameters and ``ard_num_dims`` is
            their count.
        """
        names = self.parameter_names

        # We use automatic relevance determination for all kernels, i.e. the
        # number of ARD dimensions always equals the number of active dimensions.
        if names is None:
            # `None` is gpytorch's default indicating that all dimensions are active
            return None, len(searchspace.comp_rep_columns)

        # Concatenate the indices of all selected parameters, in the given order
        indices = tuple(
            index
            for name in names
            for index in searchspace.get_comp_rep_parameter_indices(name)
        )
        return indices, len(indices)
@define(frozen=True)
class CompositeKernel(Kernel, ABC):
    """Abstract base class for all composite kernels."""

    @override
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        """Get the active dimensions and the number of ARD dimensions.

        Args:
            searchspace: The search space (not used by composite kernels).

        Returns:
            Always ``(None, None)``: a composite kernel specifies neither active
            dimensions nor a number of ARD dimensions of its own.
        """
        active_dims = ard_num_dims = None
        return active_dims, ard_num_dims
# Collect leftover original slotted classes processed by `attrs.define` gc.collect()