Source code for baybe.surrogates.gaussian_process.components.likelihood

"""Likelihood factories for the Gaussian process surrogate."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from attrs import define
from typing_extensions import override

from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.generic import (
    GPComponentFactoryProtocol,
    PlainGPComponentFactory,
)

if TYPE_CHECKING:
    # Static type checkers see the concrete GPyTorch/torch types. Because of
    # ``from __future__ import annotations``, these names are only needed for
    # annotations (e.g. ``Tensor`` and ``GPyTorchLikelihood`` in the factory
    # signatures below), so importing them here costs nothing at runtime.
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from torch import Tensor

    # Aliases of the generic component-factory types, parametrized with the
    # GPyTorch likelihood type for precise static checking.
    LikelihoodFactoryProtocol = GPComponentFactoryProtocol[GPyTorchLikelihood]
    PlainLikelihoodFactory = PlainGPComponentFactory[GPyTorchLikelihood]
else:
    # At runtime, we avoid loading GPyTorch eagerly for performance reasons.
    # The same public alias names are still bound, parametrized with ``Any``
    # instead of the GPyTorch type, so the rest of the module can use them in
    # either mode.
    LikelihoodFactoryProtocol = GPComponentFactoryProtocol[Any]
    PlainLikelihoodFactory = PlainGPComponentFactory[Any]


@define
class LazyGaussianLikelihoodFactory(LikelihoodFactoryProtocol):
    """Likelihood factory that builds GPyTorch Gaussian likelihoods on demand.

    GPyTorch is imported only when the factory is invoked. This keeps it out of
    the module's import-time cost.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> GPyTorchLikelihood:
        """Return a newly constructed ``GaussianLikelihood``.

        Args:
            searchspace: Unused; accepted to match the factory call signature.
            train_x: Unused; accepted to match the factory call signature.
            train_y: Unused; accepted to match the factory call signature.

        Returns:
            A new ``gpytorch.likelihoods.GaussianLikelihood`` built with
            GPyTorch's default constructor arguments.
        """
        # Deferred import: loading GPyTorch is postponed until a likelihood is
        # actually requested.
        from gpytorch import likelihoods

        likelihood = likelihoods.GaussianLikelihood()
        return likelihood