Source code for baybe.constraints.base

"""Base classes for all constraints."""

from __future__ import annotations

import gc
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar

import pandas as pd
from attrs import define, field
from attrs.validators import ge, instance_of, min_len

from baybe.constraints.deprecation import (
    ContinuousLinearEqualityConstraint,
    ContinuousLinearInequalityConstraint,
)
from baybe.serialization import (
    SerialMixin,
)
from baybe.serialization.core import (
    converter,
)
from baybe.utils.basic import classproperty

if TYPE_CHECKING:
    import polars as pl


@define
class Constraint(ABC, SerialMixin):
    """Abstract root of the constraint class hierarchy."""

    # class variables
    # TODO: it might turn out these are not needed at a later development stage
    eval_during_creation: ClassVar[bool]
    """Class-level flag: is the condition evaluated during creation?"""

    eval_during_modeling: ClassVar[bool]
    """Class-level flag: is the condition evaluated during modeling?"""

    eval_during_augmentation: ClassVar[bool] = False
    """Class-level flag: may the constraint be considered during data
    augmentation?"""

    numerical_only: ClassVar[bool] = False
    """Class-level flag: is the constraint valid only for numerical parameters?"""

    # Object variables
    parameters: list[str] = field(validator=min_len(1))
    """Names of the parameters the constraint operates on."""

    @parameters.validator
    def _validate_params(  # noqa: DOC101, DOC103
        self, _: Any, params: list[str]
    ) -> None:
        """Ensure that no parameter name is listed more than once.

        Raises:
            ValueError: If ``params`` contains duplicate values.
        """
        n_given = len(params)
        n_distinct = len(set(params))
        if n_given == n_distinct:
            return

        message = (
            f"The given 'parameters' list must have unique values "
            f"but was: {params}."
        )
        raise ValueError(message)

    def summary(self) -> dict:
        """Return a custom summarization of the constraint."""
        return {
            "Type": self.__class__.__name__,
            "Affected_Parameters": self.parameters,
        }

    @property
    def is_continuous(self) -> bool:
        """Whether the constraint acts on continuous parameters."""
        return isinstance(self, ContinuousConstraint)

    @property
    def is_discrete(self) -> bool:
        """Whether the constraint acts on discrete parameters."""
        return isinstance(self, DiscreteConstraint)

    @property
    def _required_parameters(self) -> set[str]:
        """Names of all parameters required to fully evaluate the constraint.

        By default, these are exactly the names stored in
        :attr:`~baybe.constraints.base.Constraint.parameters`. Constraints that
        reference further parameters (e.g., affected parameters in dependency
        constraints) override this property to add them.
        """
        return {*self.parameters}
@define
class DiscreteConstraint(Constraint, ABC):
    """Abstract base class for discrete constraints.

    A discrete constraint combines conditions to filter unwanted entries out of
    the search space.
    """

    # class variables
    eval_during_creation: ClassVar[bool] = True
    # See base class.

    eval_during_modeling: ClassVar[bool] = False
    # See base class.

    def _can_evaluate(self, available: set[str], /) -> bool:
        """Tell whether the constraint can be (partially) evaluated.

        Used to decide if the constraint logic should be invoked at all. This
        default demands that *all* parameters considered by the constraint are
        present. Subclasses supporting useful partial filtering override it to
        return ``True`` as soon as a meaningful subset is available.

        Args:
            available: The set of column names present in the dataframe that is
                about to be evaluated.

        Returns:
            ``True`` if the constraint can apply a meaningful partial filtering
            given the *available* columns, ``False`` otherwise.
        """
        unavailable = self._required_parameters - available
        return not unavailable

    def get_valid(
        self, df: pd.DataFrame, /, *, allow_missing: bool = False
    ) -> pd.Index:
        """Get the indices of dataframe entries that are valid under the constraint.

        Args:
            df: A dataframe where each row represents a parameter configuration.
            allow_missing: If ``False``, a :class:`ValueError` is raised when the
                dataframe is missing required parameter columns. If ``True``,
                the constraint performs partial filtering on the available
                columns.

        Returns:
            The dataframe indices of rows that fulfill the constraint.
        """
        # The valid rows are whatever remains after removing the invalid ones.
        return df.index.drop(self.get_invalid(df, allow_missing=allow_missing))

    def get_invalid(
        self, df: pd.DataFrame, /, *, allow_missing: bool = False
    ) -> pd.Index:
        """Get the indices of dataframe entries that are invalid under the constraint.

        Args:
            df: A dataframe where each row represents a parameter configuration.
            allow_missing: If ``False``, a :class:`ValueError` is raised when the
                dataframe is missing required parameter columns. If ``True``,
                the subclass is asked whether it can perform (partial)
                constraint evaluation; if not, an empty index is returned,
                signaling to the caller `there are no entries to be excluded
                *yet*`.

        Raises:
            ValueError: If ``allow_missing`` is ``False`` and the dataframe is
                missing required parameter columns.

        Returns:
            The dataframe indices of rows that violate the constraint.
        """
        # TODO: Should switch backends (pandas/polars/...) behind the scenes
        present = set(df.columns)

        if allow_missing:
            # Lenient mode: skip evaluation when too few columns are present
            if not self._can_evaluate(present):
                return pd.Index([])
        else:
            # Strict mode: every required column must exist
            absent = self._required_parameters - present
            if absent:
                raise ValueError(
                    f"'{self.__class__.__name__}' requires columns {absent} "
                    f"which are missing from the dataframe."
                )

        return self._get_invalid(df)

    @abstractmethod
    def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index:
        """Get the indices of invalid entries (core logic for subclasses).

        :meth:`get_invalid` calls this only once the column-availability checks
        have passed, i.e. the dataframe holds sufficient columns for (at least
        partial) evaluation. Implementations should thus be limited to the
        constraint's core filtering logic.

        Args:
            df: A dataframe where each row represents a parameter configuration.

        Returns:
            The dataframe indices of rows that violate the constraint.
        """

    @classproperty
    def has_polars_implementation(cls) -> bool:
        """Whether this constraint class has a Polars implementation."""
        # A class counts as implemented once it overrides the base-class stub
        base_stub = DiscreteConstraint.get_invalid_polars
        return cls.get_invalid_polars is not base_stub

    def get_invalid_polars(self) -> pl.Expr:
        """Translate the constraint to Polars expression identifying undesired rows.

        Returns:
            The Polars expressions to pass to :meth:`polars.LazyFrame.filter`.

        Raises:
            NotImplementedError: If the constraint class does not have a Polars
                implementation.
        """
        message = f"'{self.__class__.__name__}' does not have a Polars implementation."
        raise NotImplementedError(message)
@define
class ContinuousConstraint(Constraint, ABC):
    """Abstract parent class of all constraints over continuous parameters."""

    # class variables
    eval_during_creation: ClassVar[bool] = False
    # See base class.

    eval_during_modeling: ClassVar[bool] = True
    # See base class.

    numerical_only: ClassVar[bool] = True
    # See base class.
@define
class CardinalityConstraint(Constraint, ABC):
    r"""Abstract base class for cardinality constraints.

    Bounds the number of nonzero (i.e. "active") values among the specified
    parameters between two given integers, i.e.

    .. math::
        \text{min_cardinality} \leq |\{p_i : p_i \neq 0\}| \leq \text{max_cardinality}

    where :math:`\{p_i\}` are the parameters specified for the constraint.

    Equivalently, this is an L0-constraint on the vector containing the
    specified parameters.
    """

    # class variable
    numerical_only: ClassVar[bool] = True
    # See base class.

    # object variables
    min_cardinality: int = field(default=0, validator=[instance_of(int), ge(0)])
    """The minimum required cardinality."""

    max_cardinality: int = field(validator=instance_of(int))
    """The maximum allowed cardinality."""

    @max_cardinality.default
    def _default_max_cardinality(self):
        """Default the upper limit to the number of involved parameters."""
        return len(self.parameters)

    def __attrs_post_init__(self):
        """Validate the cardinality bounds.

        Raises:
            ValueError: If the provided cardinality bounds are invalid.
            ValueError: If the provided cardinality bounds impose no constraint.
        """
        n_parameters = len(self.parameters)
        lower, upper = self.min_cardinality, self.max_cardinality

        # The bounds must be ordered ...
        if lower > upper:
            raise ValueError(
                f"The lower cardinality bound cannot be larger than the upper bound. "
                f"Provided values: {self.max_cardinality=}, {self.min_cardinality=}."
            )

        # ... the upper bound cannot exceed the number of parameters ...
        if upper > n_parameters:
            raise ValueError(
                f"The cardinality bound cannot exceed the number of parameters. "
                f"Provided values: {self.max_cardinality=}, {len(self.parameters)=}."
            )

        # ... and the bounds must actually restrict something.
        if lower == 0 and upper == n_parameters:
            raise ValueError(
                f"No constraint of type `{self.__class__.__name__}' is required "
                f"when the lower cardinality bound is zero and the upper bound equals "
                f"the number of parameters. Provided values: {self.min_cardinality=}, "
                f"{self.max_cardinality=}, {len(self.parameters)=}"
            )
class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
    """Abstract parent class of all nonlinear continuous constraints."""
# >>>>> Deprecation handling
# Keep a handle on the original structure hook so that non-legacy configs can be
# delegated to it from within the replacement hook registered below.
_hook = converter.get_structure_hook(Constraint)


def _deprecate_legacy_classes(dct: dict[str, Any], _) -> Constraint:
    """Enable constraint configs using legacy class names.

    Configs whose ``"type"`` entry names one of the legacy linear constraint
    classes are routed to the corresponding callable imported from
    ``baybe.constraints.deprecation``. All other configs are passed on to the
    structure hook that was registered for ``Constraint`` before this one.

    The input dictionary is left untouched: the ``"type"`` entry is filtered out
    of a copy instead of being popped, so the same config can be structured
    repeatedly.

    Args:
        dct: The constraint configuration. Must contain a ``"type"`` entry.
        _: The second hook argument supplied by the converter; it is only
            forwarded to the original hook.

    Returns:
        The constraint object built from the configuration.
    """
    type_name = dct["type"]
    if type_name == "ContinuousLinearEqualityConstraint":
        factory = ContinuousLinearEqualityConstraint
    elif type_name == "ContinuousLinearInequalityConstraint":
        factory = ContinuousLinearInequalityConstraint
    else:
        return _hook(dct, _)

    # Everything except the type tag is treated as a constructor argument
    kwargs = {key: value for key, value in dct.items() if key != "type"}
    return factory(**kwargs)


converter.register_structure_hook_func(
    lambda c: c is Constraint, _deprecate_legacy_classes
)
# <<<<< Deprecation handling

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()