Source code for baybe.surrogates.gaussian_process.components.mean
"""Mean factories for the Gaussian process surrogate."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from attrs import define
from typing_extensions import override
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.generic import (
GPComponentFactoryProtocol,
PlainGPComponentFactory,
)
if TYPE_CHECKING:
    # Static type checkers see the precise GPyTorch/torch types. These names are
    # only used in annotations in this module, and `from __future__ import
    # annotations` keeps those annotations unevaluated at runtime, so the imports
    # are never executed outside of type checking.
    from gpytorch.means import Mean as GPyTorchMean
    from torch import Tensor

    # Generic component-factory types specialized to GPyTorch mean modules.
    # NOTE(review): keep these alias names in sync with the runtime branch below
    # so static and runtime views of the module expose the same names.
    MeanFactoryProtocol = GPComponentFactoryProtocol[GPyTorchMean]
    PlainMeanFactory = PlainGPComponentFactory[GPyTorchMean]
else:
    # At runtime, we avoid loading GPyTorch eagerly for performance reasons
    # (the type parameter is replaced by `Any`, so no GPyTorch import is needed).
    MeanFactoryProtocol = GPComponentFactoryProtocol[Any]
    PlainMeanFactory = PlainGPComponentFactory[Any]
[docs]
@define
class LazyConstantMeanFactory(MeanFactoryProtocol):
    """Factory that builds a GPyTorch constant mean module on demand.

    GPyTorch is imported only when the factory is invoked, so importing this
    module does not by itself pull in GPyTorch.
    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> GPyTorchMean:
        """Create a fresh constant mean module.

        The search space and training data are accepted to satisfy the factory
        protocol but are not inspected: the result is always a
        default-constructed ``ConstantMean``.

        Args:
            searchspace: The search space of the surrogate (unused).
            train_x: Training inputs (unused).
            train_y: Training targets (unused).

        Returns:
            A newly constructed ``gpytorch.means.ConstantMean`` instance.
        """
        # Deferred import keeps GPyTorch off the module-level import path.
        import gpytorch.means as gp_means

        mean_module = gp_means.ConstantMean()
        return mean_module