Source code for baybe.parameters.selectors

"""Parameter selectors."""

import re
from abc import ABC, abstractmethod
from collections.abc import Collection
from typing import Protocol

from attrs import Converter, define, field
from attrs.validators import deep_iterable, instance_of, min_len
from typing_extensions import override

from baybe.parameters.base import Parameter
from baybe.utils.basic import to_tuple
from baybe.utils.conversion import nonstring_to_tuple


class ParameterSelectorProtocol(Protocol):
    """Structural type describing the interface of a parameter selector.

    An object satisfies this protocol if it can be called with a single
    parameter and returns a boolean.
    """

    def __call__(self, parameter: Parameter) -> bool:
        """Decide whether the given parameter is part of the selection."""
        ...
@define
class ParameterSelector(ParameterSelectorProtocol, ABC):
    """Abstract foundation shared by all parameter selectors."""

    exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True)
    """Boolean flag indicating whether to invert the selection criterion."""

    @abstractmethod
    def _is_match(self, parameter: Parameter) -> bool:
        """Report whether a parameter satisfies the selection criterion."""

    @override
    def __call__(self, parameter: Parameter) -> bool:
        """Decide whether the given parameter is part of the selection."""
        matched = self._is_match(parameter)
        if self.exclude:
            # Inverted mode: select exactly the parameters that do NOT match.
            return not matched
        return matched
@define
class TypeSelector(ParameterSelector):
    """Selector that picks parameters according to their type."""

    types: tuple[type[Parameter], ...] = field(converter=to_tuple)
    """The parameter types to be selected."""

    @override
    def _is_match(self, parameter: Parameter) -> bool:
        """Report whether the parameter is an instance of one of ``types``."""
        accepted_types = self.types
        return isinstance(parameter, accepted_types)
@define
class NameSelector(ParameterSelector):
    """Selector that picks parameters according to name patterns."""

    patterns: tuple[str, ...] = field(
        converter=Converter(  # type: ignore
            nonstring_to_tuple, takes_self=True, takes_field=True
        ),
        validator=[
            min_len(1),
            deep_iterable(member_validator=instance_of(str)),
        ],
    )
    """The patterns to be matched against."""

    regex: bool = field(default=True, validator=instance_of(bool), kw_only=True)
    """If ``False``, the provided patterns are interpreted as literal strings."""

    @override
    def _is_match(self, parameter: Parameter) -> bool:
        """Report whether the parameter name matches any of ``patterns``."""
        name = parameter.name

        # Literal mode: the name must be identical to one of the patterns.
        if not self.regex:
            return name in self.patterns

        # Regex mode: a pattern must match the complete name, not just a part.
        for pattern in self.patterns:
            if re.fullmatch(pattern, name) is not None:
                return True
        return False
def to_parameter_selector(
    x: (
        str
        | type[Parameter]
        | Collection[str]
        | Collection[type[Parameter]]
        | ParameterSelectorProtocol
    ),
    /,
) -> ParameterSelectorProtocol:
    """Convert shorthand notations to parameter selectors.

    Convenience converter that allows users to specify parameter selectors using
    simpler types:

    * A single string is interpreted as a parameter name and wrapped into a
      :class:`NameSelector`.
    * A single :class:`~baybe.parameters.base.Parameter` subclass is wrapped into
      a :class:`TypeSelector`.
    * Any other callable (i.e., an existing selector or any object satisfying
      :class:`ParameterSelectorProtocol`) is passed through unchanged.
    * A collection of strings is converted to a :class:`NameSelector`.
    * A collection of :class:`~baybe.parameters.base.Parameter` subclasses is
      converted to a :class:`TypeSelector`.

    Args:
        x: The object to convert.

    Returns:
        The corresponding parameter selector.

    Raises:
        TypeError: If the input cannot be converted to a parameter selector. This
            includes objects that are neither callable nor iterable and
            collections whose items are not uniformly strings or parameter types.
    """
    if isinstance(x, str):
        return NameSelector([x])

    # Classes are callable themselves, so parameter types must be recognized
    # before the generic callable pass-through below.
    if isinstance(x, type) and issubclass(x, Parameter):
        return TypeSelector([x])

    if callable(x):
        return x

    # At this point, x should be a collection of strings or parameter types.
    # Materialize it once so that one-shot iterables can be inspected and then
    # forwarded, and so that non-iterable input yields the documented message
    # instead of a generic "object is not iterable" error.
    try:
        items = tuple(x)
    except TypeError as err:
        raise TypeError(f"Cannot convert {x!r} to a parameter selector.") from err

    # NOTE: An empty collection vacuously passes the first check and is handed
    # to NameSelector, whose own validation of ``patterns`` then applies.
    if all(isinstance(item, str) for item in items):
        return NameSelector(items)
    if all(isinstance(item, type) and issubclass(item, Parameter) for item in items):
        return TypeSelector(items)

    raise TypeError(f"Cannot convert {x!r} to a parameter selector.")