# Source code for baybe.surrogates.gaussian_process.components.mean

"""Mean factories for the Gaussian process surrogate."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from attrs import define
from typing_extensions import override

from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.generic import (
    GPComponentFactoryProtocol,
    PlainGPComponentFactory,
)

# Aliases specializing the generic GP component-factory types to mean components.
# During static type checking they are parameterized with GPyTorch's ``Mean``
# class. At runtime ``Any`` is used instead, so loading this module does not
# require importing ``gpytorch`` or ``torch``.
if TYPE_CHECKING:
    # These names are only used in annotations. Because of the
    # ``from __future__ import annotations`` at the top of the module, those
    # annotations are never evaluated at runtime, so the imports can stay here.
    from gpytorch.means import Mean as GPyTorchMean
    from torch import Tensor

    # Factory types whose produced component is a GPyTorch ``Mean``.
    MeanFactoryProtocol = GPComponentFactoryProtocol[GPyTorchMean]
    PlainMeanFactory = PlainGPComponentFactory[GPyTorchMean]
else:
    # At runtime, we avoid loading GPyTorch eagerly for performance reasons
    MeanFactoryProtocol = GPComponentFactoryProtocol[Any]
    PlainMeanFactory = PlainGPComponentFactory[Any]


@define
class LazyConstantMeanFactory(MeanFactoryProtocol):
    """Factory that builds constant mean functions, importing GPyTorch lazily.

    GPyTorch is not imported when this class is defined or instantiated. The
    import happens inside ``__call__``, i.e. only when a mean module is
    actually requested.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> GPyTorchMean:
        """Create a new constant mean module.

        The search space and training data are accepted to satisfy the factory
        interface but are not used. The result is always a
        ``gpytorch.means.ConstantMean`` constructed with its default arguments.

        Args:
            searchspace: The search space of the surrogate (unused).
            train_x: The training inputs (unused).
            train_y: The training targets (unused).

        Returns:
            A freshly constructed ``ConstantMean`` instance.
        """
        # Deferred import so that GPyTorch is only loaded when a mean is built.
        from gpytorch.means import ConstantMean as _ConstantMean

        constant_mean = _ConstantMean()
        return constant_mean