diff --git a/src/citrine/informatics/design_spaces/__init__.py b/src/citrine/informatics/design_spaces/__init__.py index 0b4804802..99fa5bc6f 100644 --- a/src/citrine/informatics/design_spaces/__init__.py +++ b/src/citrine/informatics/design_spaces/__init__.py @@ -2,7 +2,8 @@ from .data_source_design_space import * from .design_space import * from .design_space_settings import * -from .enumerated_design_space import * from .formulation_design_space import * -from .product_design_space import * from .hierarchical_design_space import * +from .product_design_space import * +from .subspace import * +from .top_level_design_space import * diff --git a/src/citrine/informatics/design_spaces/data_source_design_space.py b/src/citrine/informatics/design_spaces/data_source_design_space.py index dc8ef4409..174a4f208 100644 --- a/src/citrine/informatics/design_spaces/data_source_design_space.py +++ b/src/citrine/informatics/design_spaces/data_source_design_space.py @@ -1,12 +1,12 @@ -from citrine._rest.engine_resource import EngineResource +from citrine._rest.resource import Resource from citrine._serialization import properties from citrine.informatics.data_sources import DataSource -from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.subspace import DesignSubspace __all__ = ['DataSourceDesignSpace'] -class DataSourceDesignSpace(EngineResource['DataSourceDesignSpace'], DesignSpace): +class DataSourceDesignSpace(Resource['DataSourceDesignSpace'], DesignSubspace): """An enumeration of candidates stored in a data source. Parameters @@ -20,10 +20,9 @@ class DataSourceDesignSpace(EngineResource['DataSourceDesignSpace'], DesignSpace """ - data_source = properties.Object(DataSource, 'data.instance.data_source') + data_source = properties.Object(DataSource, 'data_source') - typ = properties.String('data.instance.type', default='DataSourceDesignSpace', - deserializable=False) + typ = properties.String('type', default='DataSourceDesignSpace', deserializable=False) def __init__(self, name: str, diff --git a/src/citrine/informatics/design_spaces/design_space.py b/src/citrine/informatics/design_spaces/design_space.py index 295d4094e..c87ee28ce 100644 --- a/src/citrine/informatics/design_spaces/design_space.py +++ b/src/citrine/informatics/design_spaces/design_space.py @@ -1,87 +1,10 @@ -"""Tools for working with design spaces.""" -from typing import Optional, Type -from uuid import UUID - -from citrine._rest.asynchronous_object import AsynchronousObject -from citrine._serialization import properties -from citrine._serialization.polymorphic_serializable import PolymorphicSerializable -from citrine._serialization.serializable import Serializable -from citrine._session import Session -from citrine.resources.sample_design_space_execution import \ - SampleDesignSpaceExecutionCollection - - __all__ = ['DesignSpace'] -class DesignSpace(PolymorphicSerializable['DesignSpace'], AsynchronousObject): - """A Citrine Design Space describes the set of materials that can be made. +class DesignSpace: + """Parent type of the individual design space on the Citrine Platform. - Abstract type that returns the proper type given a serialized dict. + The Predictor class unifies TopLevelDesignSpace types and the DesignSubspace types, + but does not provide any functionality on its own. """ - - uid = properties.Optional(properties.UUID, 'id', serializable=False) - """:Optional[UUID]: Citrine Platform unique identifier""" - name = properties.String('data.name') - description = properties.Optional(properties.String(), 'data.description') - - locked_by = properties.Optional(properties.UUID, 'metadata.locked.user', - serializable=False) - """:Optional[UUID]: id of the user whose action cause the design space to - be locked, if it is locked""" - lock_time = properties.Optional(properties.Datetime, 'metadata.locked.time', - serializable=False) - """:Optional[datetime]: date and time at which the resource was locked, - if it is locked""" - - @staticmethod - def wrap_instance(subspace_data: dict) -> dict: - """Insert a serialized embedded design space into an entity envelope. - - This facilitates deserialization. - """ - return { - "data": { - "name": subspace_data.get("name", ""), - "description": subspace_data.get("description", ""), - "instance": subspace_data - } - } - - _response_key = None - _project_id: Optional[UUID] = None - _session: Optional[Session] = None - _in_progress_statuses = ["VALIDATING", "CREATED"] - _succeeded_statuses = ["READY"] - _failed_statuses = ["INVALID", "ERROR"] - - @classmethod - def get_type(cls, data) -> Type[Serializable]: - """Return the subtype.""" - from .data_source_design_space import DataSourceDesignSpace - from .enumerated_design_space import EnumeratedDesignSpace - from .formulation_design_space import FormulationDesignSpace - from .product_design_space import ProductDesignSpace - from .hierarchical_design_space import HierarchicalDesignSpace - - return { - 'Univariate': ProductDesignSpace, - 'ProductDesignSpace': ProductDesignSpace, - 'EnumeratedDesignSpace': EnumeratedDesignSpace, - 'FormulationDesignSpace': FormulationDesignSpace, - 'DataSourceDesignSpace': DataSourceDesignSpace, - 'HierarchicalDesignSpace': HierarchicalDesignSpace - }[data['data']['instance']['type']] - - @property - def is_locked(self) -> bool: - """If is_locked is true, edits to the design space will be rejected.""" - return self.locked_by is not None - - @property - def sample_design_space_executions(self): - """Start a Sample Design Space Execution using the current Design Space.""" - return SampleDesignSpaceExecutionCollection( - project_id=self._project_id, design_space_id=self.uid, session=self._session - ) diff --git a/src/citrine/informatics/design_spaces/enumerated_design_space.py b/src/citrine/informatics/design_spaces/enumerated_design_space.py deleted file mode 100644 index 0431b69a6..000000000 --- a/src/citrine/informatics/design_spaces/enumerated_design_space.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import List, Mapping, Union - -from citrine._rest.engine_resource import EngineResource -from citrine._serialization import properties -from citrine.informatics.descriptors import Descriptor -from citrine.informatics.design_spaces.design_space import DesignSpace - -__all__ = ['EnumeratedDesignSpace'] - - -class EnumeratedDesignSpace(EngineResource['EnumeratedDesignSpace'], DesignSpace): - """An explicit enumeration of candidate materials to score. - - Enumerated design spaces are intended to capture small spaces with fewer than - 1000 values. For larger spaces, use the DataSourceDesignSpace. - - Parameters - ---------- - name:str - the name of the design space - description:str - the description of the design space - descriptors: list[Descriptor] - the list of descriptors included in the candidates of the design space - data: list[dict] - list of dicts of the shape `{: }` - where each dict corresponds to a candidate in the design space - - """ - - descriptors = properties.List(properties.Object(Descriptor), 'data.instance.descriptors') - _data = properties.List(properties.Mapping(properties.String, - properties.Union([properties.String(), - properties.Integer(), - properties.Float()])), - 'data.instance.data') - - typ = properties.String('data.instance.type', default='EnumeratedDesignSpace', - deserializable=False) - - def __init__(self, - name: str, - *, - description: str, - descriptors: List[Descriptor], - data: List[Mapping[str, Union[int, float, str]]]): - self.name: str = name - self.description: str = description - self.descriptors: List[Descriptor] = descriptors - self.data: List[Mapping[str, Union[int, float, str]]] = data - - def __str__(self): - return ''.format(self.name) - - @property - def data(self) -> List[Mapping[str, Union[int, float, str]]]: - """List of dicts corresponding to candidates in the design space.""" - return self._data - - @data.setter - def data(self, value: List[Mapping[str, str]]): - self._data = value diff --git a/src/citrine/informatics/design_spaces/formulation_design_space.py b/src/citrine/informatics/design_spaces/formulation_design_space.py index a77e65b2d..216aec80b 100644 --- a/src/citrine/informatics/design_spaces/formulation_design_space.py +++ b/src/citrine/informatics/design_spaces/formulation_design_space.py @@ -1,15 +1,15 @@ from typing import Mapping, Optional, Set -from citrine._rest.engine_resource import EngineResource +from citrine._rest.resource import Resource from citrine._serialization import properties from citrine.informatics.constraints import Constraint from citrine.informatics.descriptors import FormulationDescriptor -from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.subspace import DesignSubspace __all__ = ['FormulationDesignSpace'] -class FormulationDesignSpace(EngineResource['FormulationDesignSpace'], DesignSpace): +class FormulationDesignSpace(Resource['FormulationDesignSpace'], DesignSubspace): """Design space composed of mixtures of ingredients. Parameters @@ -36,23 +36,16 @@ class FormulationDesignSpace(EngineResource['FormulationDesignSpace'], DesignSpa """ - formulation_descriptor = properties.Object( - FormulationDescriptor, - 'data.instance.formulation_descriptor' - ) - ingredients = properties.Set(properties.String, 'data.instance.ingredients') + formulation_descriptor = properties.Object(FormulationDescriptor, 'formulation_descriptor') + ingredients = properties.Set(properties.String, 'ingredients') labels = properties.Optional(properties.Mapping( properties.String, properties.Set(properties.String) - ), 'data.instance.labels') - constraints = properties.Set(properties.Object(Constraint), 'data.instance.constraints') - resolution = properties.Float('data.instance.resolution') + ), 'labels') + constraints = properties.Set(properties.Object(Constraint), 'constraints') + resolution = properties.Float('resolution') - typ = properties.String( - 'data.instance.type', - default='FormulationDesignSpace', - deserializable=False - ) + typ = properties.String('type', default='FormulationDesignSpace', deserializable=False) def __init__(self, name: str, diff --git a/src/citrine/informatics/design_spaces/hierarchical_design_space.py b/src/citrine/informatics/design_spaces/hierarchical_design_space.py index 205441820..3a31e8b05 100644 --- a/src/citrine/informatics/design_spaces/hierarchical_design_space.py +++ b/src/citrine/informatics/design_spaces/hierarchical_design_space.py @@ -7,7 +7,7 @@ from citrine.informatics.data_sources import DataSource from citrine.informatics.dimensions import Dimension from citrine.informatics.design_spaces import FormulationDesignSpace -from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.top_level_design_space import TopLevelDesignSpace from citrine.informatics.design_spaces.design_space_settings import DesignSpaceSettings __all__ = [ @@ -108,7 +108,7 @@ def __repr__(self): return f"" -class HierarchicalDesignSpace(EngineResource["HierarchicalDesignSpace"], DesignSpace): +class HierarchicalDesignSpace(EngineResource["HierarchicalDesignSpace"], TopLevelDesignSpace): """A design space that produces hierarchical candidates representing a material history. A hierarchical design space always contains a root node that defines the @@ -185,41 +185,7 @@ def _post_dump(self, data: dict) -> dict: if self._settings: data["settings"] = self._settings.dump() - root_node = data["instance"]["root"] - data["instance"]["root"] = self.__unwrap_node(root_node) - - data["instance"]["subspaces"] = [ - self.__unwrap_node(sub_node) - for sub_node in data['instance']['subspaces'] - ] - return data - - @classmethod - def _pre_build(cls, data: dict) -> dict: - root_node = data["data"]["instance"]["root"] - data["data"]["instance"]["root"] = cls.__wrap_node(root_node) - - data["data"]["instance"]["subspaces"] = [ - cls.__wrap_node(sub_node) for sub_node in data['data']['instance']['subspaces'] - ] - return data - @staticmethod - def __wrap_node(node_data: dict) -> dict: - formulation_subspace = node_data.pop('formulation', None) - if formulation_subspace: - node_data['formulation'] = DesignSpace.wrap_instance(formulation_subspace) - return node_data - - @staticmethod - def __unwrap_node(node_data: dict) -> dict: - formulation_subspace = node_data.pop('formulation', None) - if formulation_subspace: - node_data['formulation'] = formulation_subspace['data']['instance'] - node_data['formulation']['name'] = formulation_subspace['data']['name'] - node_data['formulation']['description'] = formulation_subspace['data']['description'] - return node_data - def __repr__(self): return f'' diff --git a/src/citrine/informatics/design_spaces/product_design_space.py b/src/citrine/informatics/design_spaces/product_design_space.py index d52f6a640..ad0fed6be 100644 --- a/src/citrine/informatics/design_spaces/product_design_space.py +++ b/src/citrine/informatics/design_spaces/product_design_space.py @@ -3,14 +3,15 @@ from citrine._rest.engine_resource import EngineResource from citrine._serialization import properties -from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.top_level_design_space import TopLevelDesignSpace from citrine.informatics.design_spaces.design_space_settings import DesignSpaceSettings +from citrine.informatics.design_spaces.subspace import DesignSubspace from citrine.informatics.dimensions import Dimension __all__ = ['ProductDesignSpace'] -class ProductDesignSpace(EngineResource['ProductDesignSpace'], DesignSpace): +class ProductDesignSpace(EngineResource['ProductDesignSpace'], TopLevelDesignSpace): """A Cartesian product of design spaces. Factors can be other design spaces and/or univariate dimensions. @@ -21,9 +22,8 @@ class ProductDesignSpace(EngineResource['ProductDesignSpace'], DesignSpace): the name of the design space description:str the description of the design space - subspaces: List[Union[UUID, DesignSpace]] - the list of subspaces to combine, either design spaces defined in-line - or UUIDs that reference design spaces on the platform + subspaces: List[Union[UUID, DesignSubspace]] + the list of subspaces to combine, defined in-line dimensions: list[Dimension] univariate dimensions that are factors of the design space; can be enumerated or continuous @@ -31,7 +31,7 @@ class ProductDesignSpace(EngineResource['ProductDesignSpace'], DesignSpace): _settings = properties.Optional(properties.Object(DesignSpaceSettings), "metadata.settings") - subspaces = properties.List(properties.Object(DesignSpace), 'data.instance.subspaces', + subspaces = properties.List(properties.Object(DesignSubspace), 'data.instance.subspaces', default=[]) dimensions = properties.Optional( properties.List(properties.Object(Dimension)), 'data.instance.dimensions' @@ -44,11 +44,11 @@ def __init__(self, name: str, *, description: str, - subspaces: Optional[List[Union[UUID, DesignSpace]]] = None, + subspaces: Optional[List[Union[UUID, DesignSubspace]]] = None, dimensions: Optional[List[Dimension]] = None): self.name: str = name self.description: str = description - self.subspaces: List[Union[UUID, DesignSpace]] = subspaces or [] + self.subspaces: List[Union[UUID, DesignSubspace]] = subspaces or [] self.dimensions: List[Dimension] = dimensions or [] def _post_dump(self, data: dict) -> dict: @@ -57,17 +57,6 @@ def _post_dump(self, data: dict) -> dict: if self._settings: data["settings"] = self._settings.dump() - for i, subspace in enumerate(data['instance']['subspaces']): - if isinstance(subspace, dict): - # embedded design spaces are not modules, so only serialize their config - data['instance']['subspaces'][i] = subspace['instance'] - return data - - @classmethod - def _pre_build(cls, data: dict) -> dict: - for i, subspace_data in enumerate(data['data']['instance']['subspaces']): - if isinstance(subspace_data, dict): - data['data']['instance']['subspaces'][i] = DesignSpace.wrap_instance(subspace_data) return data def __str__(self): diff --git a/src/citrine/informatics/design_spaces/subspace.py b/src/citrine/informatics/design_spaces/subspace.py new file mode 100644 index 000000000..a6a3a614e --- /dev/null +++ b/src/citrine/informatics/design_spaces/subspace.py @@ -0,0 +1,38 @@ +from typing import Type + +from citrine._serialization import properties +from citrine._serialization.polymorphic_serializable import PolymorphicSerializable +from citrine.informatics.design_spaces.design_space import DesignSpace + + +class DesignSubspace(PolymorphicSerializable["DesignSubspace"], DesignSpace): + """An individual subspace within a Design Space. + + A DesignSubspace cannot be registered to the Citrine Platform by itself + and must be included as a component within a ProductDesignSpace or + HierarchicalDesignSpace to be used. + + """ + + name = properties.String("name") + description = properties.Optional(properties.String(), "description") + + @classmethod + def get_type(cls, data) -> Type['DesignSubspace']: + """Return the subtype.""" + from .data_source_design_space import DataSourceDesignSpace + from .formulation_design_space import FormulationDesignSpace + + type_dict = { + 'FormulationDesignSpace': FormulationDesignSpace, + 'DataSourceDesignSpace': DataSourceDesignSpace, + } + + typ = type_dict.get(data['type']) + if typ is not None: + return typ + else: + raise ValueError( + '{} is not a valid design subspace type. ' + 'Must be in {}.'.format(data['type'], type_dict.keys()) + ) diff --git a/src/citrine/informatics/design_spaces/top_level_design_space.py b/src/citrine/informatics/design_spaces/top_level_design_space.py new file mode 100644 index 000000000..033178e31 --- /dev/null +++ b/src/citrine/informatics/design_spaces/top_level_design_space.py @@ -0,0 +1,80 @@ +"""Tools for working with design spaces.""" +from typing import Optional, Type +from uuid import UUID + +from citrine._rest.asynchronous_object import AsynchronousObject +from citrine._serialization import properties +from citrine._serialization.polymorphic_serializable import PolymorphicSerializable +from citrine._serialization.serializable import Serializable +from citrine._session import Session +from citrine.resources.sample_design_space_execution import \ + SampleDesignSpaceExecutionCollection + + +__all__ = ['TopLevelDesignSpace'] + + +class TopLevelDesignSpace(PolymorphicSerializable['TopLevelDesignSpace'], AsynchronousObject): + """A top-level Citrine Design Space describes the set of materials that can be made. + + Abstract type that returns the proper type given a serialized dict. + + """ + + uid = properties.Optional(properties.UUID, 'id', serializable=False) + """:Optional[UUID]: Citrine Platform unique identifier""" + name = properties.String('data.name') + description = properties.Optional(properties.String(), 'data.description') + + locked_by = properties.Optional(properties.UUID, 'metadata.locked.user', + serializable=False) + """:Optional[UUID]: id of the user whose action cause the design space to + be locked, if it is locked""" + lock_time = properties.Optional(properties.Datetime, 'metadata.locked.time', + serializable=False) + """:Optional[datetime]: date and time at which the resource was locked, + if it is locked""" + + @staticmethod + def wrap_instance(subspace_data: dict) -> dict: + """Insert a serialized embedded design space into an entity envelope. + + This facilitates deserialization. + """ + return { + "data": { + "name": subspace_data.get("name", ""), + "description": subspace_data.get("description", ""), + "instance": subspace_data + } + } + + _response_key = None + _project_id: Optional[UUID] = None + _session: Optional[Session] = None + _in_progress_statuses = ["VALIDATING", "CREATED"] + _succeeded_statuses = ["READY"] + _failed_statuses = ["INVALID", "ERROR"] + + @classmethod + def get_type(cls, data) -> Type[Serializable]: + """Return the subtype.""" + from .product_design_space import ProductDesignSpace + from .hierarchical_design_space import HierarchicalDesignSpace + + return { + 'ProductDesignSpace': ProductDesignSpace, + 'HierarchicalDesignSpace': HierarchicalDesignSpace + }[data['data']['instance']['type']] + + @property + def is_locked(self) -> bool: + """If is_locked is true, edits to the design space will be rejected.""" + return self.locked_by is not None + + @property + def sample_design_space_executions(self): + """Start a Sample Design Space Execution using the current Design Space.""" + return SampleDesignSpaceExecutionCollection( + project_id=self._project_id, design_space_id=self.uid, session=self._session + ) diff --git a/src/citrine/resources/design_space.py b/src/citrine/resources/design_space.py index fb1aa5f69..bbef4cf22 100644 --- a/src/citrine/resources/design_space.py +++ b/src/citrine/resources/design_space.py @@ -1,21 +1,17 @@ """Resources that represent collections of design spaces.""" -import warnings from functools import partial -from typing import Iterable, Iterator, Optional, TypeVar, Union +from typing import Iterable, Optional, Union from uuid import UUID from citrine._utils.functions import format_escaped_url -from citrine.informatics.design_spaces import DataSourceDesignSpace, DefaultDesignSpaceMode, \ - DesignSpace, DesignSpaceSettings, EnumeratedDesignSpace, FormulationDesignSpace, \ - HierarchicalDesignSpace +from citrine.informatics.design_spaces import DefaultDesignSpaceMode, DesignSpaceSettings, \ + HierarchicalDesignSpace, TopLevelDesignSpace from citrine._rest.collection import Collection from citrine._session import Session -CreationType = TypeVar('CreationType', bound=DesignSpace) - -class DesignSpaceCollection(Collection[DesignSpace]): +class DesignSpaceCollection(Collection[TopLevelDesignSpace]): """Represents the collection of design spaces as well as the resources belonging to it. Parameters @@ -28,7 +24,7 @@ class DesignSpaceCollection(Collection[DesignSpace]): _api_version = 'v3' _path_template = '/projects/{project_id}/design-spaces' _individual_key = None - _resource = DesignSpace + _resource = TopLevelDesignSpace _collection_key = 'response' _enumerated_cell_limit = 128 * 2000 @@ -36,67 +32,15 @@ def __init__(self, project_id: UUID, session: Session): self.project_id = project_id self.session: Session = session - def build(self, data: dict) -> DesignSpace: + def build(self, data: dict) -> TopLevelDesignSpace: """Build an individual design space.""" - design_space: DesignSpace = DesignSpace.build(data) + design_space: TopLevelDesignSpace = TopLevelDesignSpace.build(data) design_space._session = self.session design_space._project_id = self.project_id return design_space - def _verify_write_request(self, design_space: DesignSpace): - """Perform write-time validations of the design space registration or update. - - EnumeratedDesignSpaces can be pretty big, so we want to return a helpful error message - rather than let the POST or PUT call fail because the request body is too big. This - validation is performed when the design space is sent to the platform in case a user - creates a large intermediate design space but then filters it down before registering it. - - Additionally, checks for deprecated top-level design space types, and emits deprecation - warnings as appropriate. - """ - if isinstance(design_space, EnumeratedDesignSpace): - warnings.warn("As of 3.27.0, EnumeratedDesignSpace is deprecated in favor of a " - "ProductDesignSpace containing a DataSourceDesignSpace subspace. " - "Support for EnumeratedDesignSpace will be dropped in 4.0.", - DeprecationWarning) - - width = len(design_space.descriptors) - length = len(design_space.data) - if width * length > self._enumerated_cell_limit: - msg = "EnumeratedDesignSpace only supports up to {} descriptor-values, " \ - "but {} were given. Please reduce the number of descriptors or candidates " \ - "in this EnumeratedDesignSpace" - raise ValueError(msg.format(self._enumerated_cell_limit, width * length)) - elif isinstance(design_space, (DataSourceDesignSpace, FormulationDesignSpace)): - typ = type(design_space).__name__ - warnings.warn(f"As of 3.27.0, saving a top-level {typ} is deprecated. Support " - "will be removed in 4.0. Wrap it in a ProductDesignSpace instead: " - f"ProductDesignSpace('name', 'description', subspaces=[{typ}(...)])", - DeprecationWarning) - - def _verify_read_request(self, design_space: DesignSpace): - """Perform read-time validations of the design space. - - Checks for deprecated top-level design space types, and emits deprecation warnings as - appropriate. - """ - if isinstance(design_space, EnumeratedDesignSpace): - warnings.warn("As of 3.27.0, EnumeratedDesignSpace is deprecated in favor of a " - "ProductDesignSpace containing a DataSourceDesignSpace subspace. " - "Support for EnumeratedDesignSpace will be dropped in 4.0.", - DeprecationWarning) - elif isinstance(design_space, (DataSourceDesignSpace, FormulationDesignSpace)): - typ = type(design_space).__name__ - warnings.warn(f"As of 3.27.0, top-level {typ}s are deprecated. Any that remain when " - "SDK 4.0 are released will be wrapped in a ProductDesignSpace. You " - "can wrap it yourself to get rid of this warning now: " - f"ProductDesignSpace('name', 'description', subspaces=[{typ}(...)])", - DeprecationWarning) - - def register(self, design_space: DesignSpace) -> DesignSpace: + def register(self, design_space: TopLevelDesignSpace) -> TopLevelDesignSpace: """Create a new design space.""" - self._verify_write_request(design_space) - registered_ds = super().register(design_space) # If the initial response is invalid, just return it. @@ -107,9 +51,8 @@ def register(self, design_space: DesignSpace) -> DesignSpace: else: return self._validate(registered_ds.uid) - def update(self, design_space: DesignSpace) -> DesignSpace: + def update(self, design_space: TopLevelDesignSpace) -> TopLevelDesignSpace: """Update and validate an existing DesignSpace.""" - self._verify_write_request(design_space) updated_ds = super().update(design_space) # If the initial response is invalid, just return it. @@ -120,12 +63,12 @@ def update(self, design_space: DesignSpace) -> DesignSpace: else: return self._validate(updated_ds.uid) - def _validate(self, uid: Union[UUID, str]) -> DesignSpace: + def _validate(self, uid: Union[UUID, str]) -> TopLevelDesignSpace: path = self._get_path(uid, action="validate") entity = self.session.put_resource(path, {}, version=self._api_version) return self.build(entity) - def archive(self, uid: Union[UUID, str]) -> DesignSpace: + def archive(self, uid: Union[UUID, str]) -> TopLevelDesignSpace: """Archiving a design space removes it from view, but is not a hard delete. Parameters @@ -138,7 +81,7 @@ def archive(self, uid: Union[UUID, str]) -> DesignSpace: entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) - def restore(self, uid: Union[UUID, str]) -> DesignSpace: + def restore(self, uid: Union[UUID, str]) -> TopLevelDesignSpace: """Restore an archived design space. Parameters @@ -151,31 +94,6 @@ def restore(self, uid: Union[UUID, str]) -> DesignSpace: entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) - def get(self, uid: Union[UUID, str]) -> DesignSpace: - """Get a particular element of the collection.""" - design_space = super().get(uid) - self._verify_read_request(design_space) - return design_space - - def _build_collection_elements(self, collection: Iterable[dict]) -> Iterator[DesignSpace]: - """ - For each element in the collection, build the appropriate resource type. - - Parameters - --------- - collection: Iterable[dict] - collection containing the elements to be built - - Returns - ------- - Iterator[DesignSpace] - Resources in this collection. - - """ - for design_space in super()._build_collection_elements(collection=collection): - self._verify_read_request(design_space) - yield design_space - def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None): filters = {} if archived is not None: @@ -186,15 +104,15 @@ def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None): collection_builder=self._build_collection_elements, per_page=per_page) - def list_all(self, *, per_page: int = 20) -> Iterable[DesignSpace]: + def list_all(self, *, per_page: int = 20) -> Iterable[TopLevelDesignSpace]: """List all design spaces.""" return self._list_base(per_page=per_page) - def list(self, *, per_page: int = 20) -> Iterable[DesignSpace]: + def list(self, *, per_page: int = 20) -> Iterable[TopLevelDesignSpace]: """List non-archived design spaces.""" return self._list_base(per_page=per_page, archived=False) - def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]: + def list_archived(self, *, per_page: int = 20) -> Iterable[TopLevelDesignSpace]: """List archived design spaces.""" return self._list_base(per_page=per_page, archived=True) @@ -206,7 +124,7 @@ def create_default(self, include_ingredient_fraction_constraints: bool = False, include_label_fraction_constraints: bool = False, include_label_count_constraints: bool = False, - include_parameter_constraints: bool = False) -> DesignSpace: + include_parameter_constraints: bool = False) -> TopLevelDesignSpace: """Create a default design space for a predictor. This method will return an unregistered design space for all inputs @@ -250,7 +168,7 @@ def create_default(self, Returns ------- - DesignSpace + TopLevelDesignSpace Default design space """ @@ -266,7 +184,7 @@ def create_default(self, ) data = self.session.post_resource(path, json=settings.dump(), version=self._api_version) - ds = self.build(DesignSpace.wrap_instance(data["instance"])) + ds = self.build(TopLevelDesignSpace.wrap_instance(data["instance"])) ds._settings = settings return ds @@ -311,7 +229,7 @@ def convert_to_hierarchical( if predictor_version: payload["predictor_version"] = predictor_version data = self.session.post_resource(path, json=payload, version=self._api_version) - return HierarchicalDesignSpace.build(DesignSpace.wrap_instance(data["instance"])) + return HierarchicalDesignSpace.build(TopLevelDesignSpace.wrap_instance(data["instance"])) def delete(self, uid: Union[UUID, str]): """Design Spaces cannot be deleted at this time.""" diff --git a/tests/conftest.py b/tests/conftest.py index 8c1c103ee..bbcba1583 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,65 +116,6 @@ def valid_product_design_space_data(): ) -@pytest.fixture -def valid_enumerated_design_space_data(): - """Produce valid enumerated design space data.""" - user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' - return dict( - id=str(uuid.uuid4()), - data=dict( - name='my enumerated design space', - description='enumerates some things', - instance=dict( - type='EnumeratedDesignSpace', - name='my enumerated design space', - description='enumerates some things', - descriptors=[ - dict( - type='Real', - descriptor_key='x', - units='', - lower_bound=1.0, - upper_bound=2.0, - ), - dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], - ), - dict( - type='Inorganic', - descriptor_key='formula' - ) - ], - data=[ - dict(x='1', color='red', formula='C44H54Si2'), - dict(x='2.0', color='green', formula='V2O3') - ] - ) - ), - metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - archived=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) - ) - - @pytest.fixture def valid_formulation_design_space_data(): """Produce valid formulation design space data.""" @@ -182,42 +123,15 @@ def valid_formulation_design_space_data(): from citrine.informatics.descriptors import FormulationDescriptor descriptor = FormulationDescriptor.hierarchical() constraint = IngredientCountConstraint(formulation_descriptor=descriptor, min=0, max=1) - user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' return dict( - id=str(uuid.uuid4()), - data=dict( - name='formulation design space', - description='formulates some things', - instance=dict( - type='FormulationDesignSpace', - name='formulation design space', - description='formulates some things', - formulation_descriptor=descriptor.dump(), - ingredients=['foo'], - labels={'bar': ['foo']}, - constraints=[constraint.dump()], - resolution=0.1 - ) - ), - metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - archived=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + type='FormulationDesignSpace', + name='formulation design space', + description='formulates some things', + formulation_descriptor=descriptor.dump(), + ingredients=['foo'], + labels={'bar': ['foo']}, + constraints=[constraint.dump()], + resolution=0.1 ) @@ -297,7 +211,7 @@ def valid_material_node_definition_data(valid_formulation_design_space_data): list=['red'] ) ], - formulation=valid_formulation_design_space_data["data"]["instance"], + formulation=valid_formulation_design_space_data, template=dict( material_template=str(uuid.uuid4()), process_template=str(uuid.uuid4()), @@ -581,34 +495,11 @@ def valid_ingredient_fractions_predictor_data(): @pytest.fixture def valid_data_source_design_space_dict(valid_gem_data_source_dict): - user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' return dict( - id=str(uuid.uuid4()), - data=dict( - name="Example valid data source design space", - description="Example valid data source design space based on a GEM Table Data Source.", - instance=dict( - type="DataSourceDesignSpace", - name="Example valid data source design space", - description="Example valid data source design space based on a GEM Table Data Source.", - data_source=valid_gem_data_source_dict - ) - ), - metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + type="DataSourceDesignSpace", + name="Example valid data source design space", + description="Example valid data source design space based on a GEM Table Data Source.", + data_source=valid_gem_data_source_dict ) @@ -652,6 +543,18 @@ def invalid_graph_predictor_data(): ) +@pytest.fixture +def invalid_design_subspace_data(): + """Produce invalid valid data used for tests.""" + return dict( + type='invalid', + name='my design space', + description='does some things', + subspaces=[], + dimensions=[] + ) + + @pytest.fixture def valid_simple_mixture_predictor_data(): """Produce valid data used for tests.""" diff --git a/tests/informatics/test_design_spaces.py b/tests/informatics/test_design_spaces.py index 9f92188f7..691f97d3b 100644 --- a/tests/informatics/test_design_spaces.py +++ b/tests/informatics/test_design_spaces.py @@ -26,15 +26,6 @@ def product_design_space() -> ProductDesignSpace: return ProductDesignSpace(name='my design space', description='does some things', dimensions=dimensions) -@pytest.fixture -def enumerated_design_space() -> EnumeratedDesignSpace: - """Build an EnumeratedDesignSpace for testing.""" - x = RealDescriptor('x', lower_bound=0.0, upper_bound=1.0, units='') - color = CategoricalDescriptor('color', categories=['r', 'g', 'b']) - data = [dict(x='0', color='r'), dict(x='1.0', color='b')] - return EnumeratedDesignSpace('enumerated', description='desc', descriptors=[x, color], data=data) - - @pytest.fixture def formulation_design_space() -> FormulationDesignSpace: desc = FormulationDescriptor.hierarchical() @@ -98,16 +89,6 @@ def test_product_initialization(product_design_space): assert product_design_space.dimensions[2].descriptor.key == 'gamma' -def test_enumerated_initialization(enumerated_design_space): - """Make sure the correct fields go to the correct places.""" - assert enumerated_design_space.name == 'enumerated' - assert enumerated_design_space.description == 'desc' - assert len(enumerated_design_space.descriptors) == 2 - assert enumerated_design_space.descriptors[0].key == 'x' - assert enumerated_design_space.descriptors[1].key == 'color' - assert enumerated_design_space.data == [{'x': '0', 'color': 'r'}, {'x': '1.0', 'color': 'b'}] - - def test_hierarchical_initialization(hierarchical_design_space): assert hierarchical_design_space.root.formulation_subspace is not None assert len(hierarchical_design_space.subspaces) == 1 @@ -121,19 +102,19 @@ def test_hierarchical_initialization(hierarchical_design_space): def test_data_source_build(valid_data_source_design_space_dict): - ds = DesignSpace.build(valid_data_source_design_space_dict) - assert ds.name == valid_data_source_design_space_dict["data"]["instance"]["name"] - assert ds.description == valid_data_source_design_space_dict["data"]["description"] - assert ds.data_source == DataSource.build(valid_data_source_design_space_dict["data"]["instance"]["data_source"]) + ds = DataSourceDesignSpace.build(valid_data_source_design_space_dict) + assert ds.name == valid_data_source_design_space_dict["name"] + assert ds.description == valid_data_source_design_space_dict["description"] + assert ds.data_source == DataSource.build(valid_data_source_design_space_dict["data_source"]) assert str(ds) == f"" def test_data_source_initialization(valid_data_source_design_space_dict): - data = valid_data_source_design_space_dict["data"] - data_source = DataSource.build(data["instance"]["data_source"]) - ds = DataSourceDesignSpace(name=data["instance"]["name"], + data = valid_data_source_design_space_dict + data_source = DataSource.build(data["data_source"]) + ds = DataSourceDesignSpace(name=data["name"], description=data["description"], data_source=data_source) - assert ds.name == data["instance"]["name"] + assert ds.name == data["name"] assert ds.description == data["description"] - assert ds.data_source.dump() == data["instance"]["data_source"] + assert ds.data_source.dump() == data["data_source"] diff --git a/tests/informatics/test_informatics.py b/tests/informatics/test_informatics.py index 07f8af2d6..0e17e698d 100644 --- a/tests/informatics/test_informatics.py +++ b/tests/informatics/test_informatics.py @@ -4,7 +4,7 @@ from citrine.informatics.constraints import ScalarRangeConstraint, AcceptableCategoriesConstraint, \ IngredientCountConstraint, IngredientFractionConstraint, IngredientRatioConstraint, \ LabelFractionConstraint, IntegerRangeConstraint -from citrine.informatics.design_spaces import ProductDesignSpace, EnumeratedDesignSpace, FormulationDesignSpace +from citrine.informatics.design_spaces import ProductDesignSpace, FormulationDesignSpace from citrine.informatics.objectives import ScalarMaxObjective, ScalarMinObjective from citrine.informatics.scores import LIScore, EIScore, EVScore @@ -31,7 +31,6 @@ (IngredientRatioConstraint(formulation_descriptor=FormulationDescriptor('Flat Formulation'), min=0.0, max=1.0, ingredient=("x", 1.5), label=("x'", 0.5), basis_ingredients=["y", "z"], basis_labels=["y'", "z'"]), ""), (ProductDesignSpace(name='my design space', description='does some things'), ""), - (EnumeratedDesignSpace('enumerated', description='desc', descriptors=[], data=[]), ""), (FormulationDesignSpace( name='Formulation', description='desc', diff --git a/tests/resources/test_design_space.py b/tests/resources/test_design_space.py index f56f3ea85..50ace6aa1 100644 --- a/tests/resources/test_design_space.py +++ b/tests/resources/test_design_space.py @@ -8,8 +8,8 @@ from citrine.exceptions import ModuleRegistrationFailedException, NotFound from citrine.informatics.descriptors import RealDescriptor, FormulationKey -from citrine.informatics.design_spaces import DefaultDesignSpaceMode, DesignSpace, \ - DesignSpaceSettings, EnumeratedDesignSpace, HierarchicalDesignSpace, ProductDesignSpace +from citrine.informatics.design_spaces import DefaultDesignSpaceMode, DesignSpaceSettings, \ + DesignSubspace, HierarchicalDesignSpace, ProductDesignSpace, TopLevelDesignSpace from citrine.resources.design_space import DesignSpaceCollection from citrine.resources.status_detail import StatusDetail, StatusLevelEnum from tests.utils.session import FakeCall, FakeSession @@ -46,7 +46,7 @@ def _ds_to_response(ds, status="CREATED"): @pytest.fixture def valid_product_design_space(valid_product_design_space_data) -> ProductDesignSpace: data = deepcopy(valid_product_design_space_data) - return DesignSpace.build(data) + return TopLevelDesignSpace.build(data) def test_design_space_build(valid_product_design_space_data): @@ -80,8 +80,7 @@ def test_design_space_build_with_status_detail(valid_product_design_space_data): def test_formulation_build(valid_formulation_design_space_data): - pc = DesignSpaceCollection(uuid.uuid4(), None) - design_space = pc.build(valid_formulation_design_space_data) + design_space = DesignSubspace.build(valid_formulation_design_space_data) assert design_space.name == 'formulation design space' assert design_space.description == 'formulates some things' assert design_space.formulation_descriptor.key == FormulationKey.HIERARCHICAL.value @@ -91,6 +90,10 @@ def test_formulation_build(valid_formulation_design_space_data): assert design_space.resolution == 0.1 +def test_unsupported_subspace_type(): + pass + + def test_hierarchical_build(valid_hierarchical_design_space_data): dc = DesignSpaceCollection(uuid.uuid4(), None) hds = dc.build(valid_hierarchical_design_space_data) @@ -130,62 +133,6 @@ def test_convert_to_hierarchical(valid_hierarchical_design_space_data): assert session.last_call == expected_call -def test_design_space_limits(): - """Test that the validation logic is triggered before post/put-ing enumerated design spaces.""" - # Given - session = FakeSession() - collection = DesignSpaceCollection(uuid.uuid4(), session) - - descriptors = [RealDescriptor(f"R-{i}", lower_bound=0, upper_bound=1, units="") for i in range(128)] - descriptor_values = {f"R-{i}": str(random.random()) for i in range(128)} - - just_right = EnumeratedDesignSpace( - "just right", - description="just right desc", - descriptors=descriptors, - data=[descriptor_values] * 2000 - ) - - too_big = EnumeratedDesignSpace( - "too big", - description="too big desc", - descriptors=just_right.descriptors, - data=[descriptor_values] * 2001 - ) - - # create mock post response by setting the status. - # Deserializing that huge dict takes a long time, and it's done twice when making a call to - # register or update (the second is the automatic validation kick-off). Since we're only - # interested in checking the validation pre-request, we can specify a tiny response to speed up - # the test execution. - dummy_desc = descriptors[0] - dummy_resp = EnumeratedDesignSpace( - "basic", - description="basic desc", - descriptors=[dummy_desc], - data=[{dummy_desc.key: descriptor_values[dummy_desc.key]}] - ) - mock_response = _ds_to_response(dummy_resp, status="READY") - session.responses.append(mock_response) - - # Then - with pytest.deprecated_call(): - with pytest.raises(ValueError) as excinfo: - collection.register(too_big) - assert "only supports" in str(excinfo.value) - - # test register - with pytest.deprecated_call(): - collection.register(just_right) - - # add back the response for the next test - session.responses.append(mock_response) - - # test update - with pytest.deprecated_call(): - collection.update(just_right) - - @pytest.mark.parametrize("predictor_version", (2, "1", "latest", None)) def test_create_default(predictor_version, valid_product_design_space): session = FakeSession() @@ -437,6 +384,24 @@ def test_failed_register(valid_product_design_space_data): assert retval.dump() == ds.dump() +def test_update(valid_product_design_space_data): + response_data = deepcopy(valid_product_design_space_data) + + session = FakeSession() + session.set_response(response_data) + dsc = DesignSpaceCollection(uuid.uuid4(), session) + ds = dsc.build(deepcopy(valid_product_design_space_data)) + + retval = dsc.update(ds) + + base_path = f"/projects/{dsc.project_id}/design-spaces" + assert session.calls == [ + FakeCall(method='PUT', path=f'{base_path}/{ds.uid}', json=ds.dump()), + FakeCall(method='PUT', path=f'{base_path}/{ds.uid}/validate', json={}) + ] + assert retval.dump() == ds.dump() + + def test_failed_update(valid_product_design_space_data): response_data = deepcopy(valid_product_design_space_data) response_data['metadata']['status']['name'] = 'INVALID' @@ -566,28 +531,3 @@ def test_locked(valid_product_design_space_data): assert ds.is_locked assert ds.locked_by == lock_user assert ds.lock_time == lock_time - - -@pytest.mark.parametrize("ds_data_fixture_name", ("valid_formulation_design_space_data", - "valid_enumerated_design_space_data", - "valid_data_source_design_space_dict")) -def test_deprecated_top_level_design_spaces(request, ds_data_fixture_name): - ds_data = request.getfixturevalue(ds_data_fixture_name) - - session = FakeSession() - session.set_response(ds_data) - dc = DesignSpaceCollection(uuid.uuid4(), session) - - with pytest.deprecated_call(): - ds = dc.get(uuid.uuid4()) - - with pytest.deprecated_call(): - dc.register(ds) - - with pytest.deprecated_call(): - dc.update(ds) - - session.set_response({"response": [ds_data]}) - - with pytest.deprecated_call(): - next(dc.list()) diff --git a/tests/resources/test_sample_design_space.py b/tests/resources/test_sample_design_space.py index 005b7aa33..c27ecb0da 100644 --- a/tests/resources/test_sample_design_space.py +++ b/tests/resources/test_sample_design_space.py @@ -1,7 +1,7 @@ import pytest import uuid -from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.top_level_design_space import TopLevelDesignSpace from citrine.informatics.design_spaces.sample_design_space import SampleDesignSpaceInput from citrine.informatics.executions.sample_design_space_execution import SampleDesignSpaceExecution from citrine.resources.sample_design_space_execution import SampleDesignSpaceExecutionCollection @@ -15,18 +15,13 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> SampleDesignSpaceExecutionCollection: - ds = DesignSpace() + ds = TopLevelDesignSpace() ds._project_id = uuid.uuid4() ds.uid = uuid.uuid4() ds._session = session return ds.sample_design_space_executions -@pytest.fixture -def design_space() -> DesignSpace: - return - - @pytest.fixture def sample_design_space_execution(collection: SampleDesignSpaceExecutionCollection, sample_design_space_execution_dict) -> SampleDesignSpaceExecution: return collection.build(sample_design_space_execution_dict) diff --git a/tests/serialization/__init__.py b/tests/serialization/__init__.py index 614ec6eae..ac519d948 100644 --- a/tests/serialization/__init__.py +++ b/tests/serialization/__init__.py @@ -16,7 +16,7 @@ def design_space_serialization_check(data, moduleClass): """ module = moduleClass.build(data) serialized = module.dump() - assert serialized == valid_serialization_output(data)['data'] + assert serialized == valid_serialization_output(data) def predictor_serialization_check(json, module_class): diff --git a/tests/serialization/test_design_spaces.py b/tests/serialization/test_design_spaces.py index 64f085a21..67b64363e 100644 --- a/tests/serialization/test_design_spaces.py +++ b/tests/serialization/test_design_spaces.py @@ -8,14 +8,13 @@ from citrine.informatics.constraints import IngredientCountConstraint from citrine.informatics.descriptors import CategoricalDescriptor, RealDescriptor, ChemicalFormulaDescriptor,\ FormulationDescriptor -from citrine.informatics.design_spaces import DesignSpace, ProductDesignSpace, EnumeratedDesignSpace,\ - FormulationDesignSpace +from citrine.informatics.design_spaces import DesignSpace, DesignSubspace, FormulationDesignSpace, ProductDesignSpace, TopLevelDesignSpace from citrine.informatics.dimensions import ContinuousDimension, EnumeratedDimension def test_product_deserialization(valid_product_design_space_data): """Ensure that a deserialized ProductDesignSpace looks sane.""" - for designSpaceClass in [ProductDesignSpace, DesignSpace]: + for designSpaceClass in [ProductDesignSpace, TopLevelDesignSpace]: data = deepcopy(valid_product_design_space_data) design_space: ProductDesignSpace = designSpaceClass.build(data) assert design_space.name == 'my design space' @@ -25,9 +24,7 @@ def test_product_deserialization(valid_product_design_space_data): assert type(design_space.dimensions[1]) == EnumeratedDimension assert design_space.dimensions[1].values == ['red'] assert type(design_space.subspaces[0]) == FormulationDesignSpace - assert design_space.subspaces[0].uid is None assert type(design_space.subspaces[1]) == FormulationDesignSpace - assert design_space.subspaces[1].uid is None assert design_space.subspaces[1].ingredients == {'baz'} @@ -41,47 +38,10 @@ def test_product_serialization(valid_product_design_space_data): assert serialized['instance']['subspaces'][1] == original_data['data']['instance']['subspaces'][1] -def test_enumerated_deserialization(valid_enumerated_design_space_data): - """Ensure that a deserialized EnumeratedDesignSpace looks sane. - Deserialization is done both directly (using EnumeratedDesignSpace) - and polymorphically (using DesignSpace) - """ - for designSpaceClass in [DesignSpace, EnumeratedDesignSpace]: - design_space: EnumeratedDesignSpace = designSpaceClass.build(valid_enumerated_design_space_data) - assert design_space.name == 'my enumerated design space' - assert design_space.description == 'enumerates some things' - - assert len(design_space.descriptors) == 3 - - real, categorical, formula = design_space.descriptors - - assert type(real) == RealDescriptor - assert real.key == 'x' - assert real.units == '' - assert real.lower_bound == 1.0 - assert real.upper_bound == 2.0 - - assert type(categorical) == CategoricalDescriptor - assert categorical.key == 'color' - assert categorical.categories == {'red', 'green', 'blue'} - - assert type(formula) == ChemicalFormulaDescriptor - assert formula.key == 'formula' - - assert len(design_space.data) == 2 - assert design_space.data[0] == {'x': '1', 'color': 'red', 'formula': 'C44H54Si2'} - assert design_space.data[1] == {'x': '2.0', 'color': 'green', 'formula': 'V2O3'} - - -def test_enumerated_serialization(valid_enumerated_design_space_data): - """Ensure that a serialized EnumeratedDesignSpace looks sane.""" - design_space_serialization_check(valid_enumerated_design_space_data, EnumeratedDesignSpace) - - def test_formulation_deserialization(valid_formulation_design_space_data): """Ensure that a deserialized FormulationDesignSpace looks sane. Deserialization is done both directly (using FormulationDesignSpace) - and polymorphically (using DesignSpace) + and polymorphically (using DesignSubspace) """ expected_descriptor = FormulationDescriptor.hierarchical() expected_constraint = IngredientCountConstraint( @@ -89,7 +49,7 @@ def test_formulation_deserialization(valid_formulation_design_space_data): min=0, max=1 ) - for designSpaceClass in [DesignSpace, FormulationDesignSpace]: + for designSpaceClass in [DesignSubspace, FormulationDesignSpace]: design_space: FormulationDesignSpace = designSpaceClass.build(valid_formulation_design_space_data) assert design_space.name == 'formulation design space' assert design_space.description == 'formulates some things' @@ -107,3 +67,9 @@ def test_formulation_deserialization(valid_formulation_design_space_data): def test_formulation_serialization(valid_formulation_design_space_data): """Ensure that a serialized FormulationDesignSpace looks sane.""" design_space_serialization_check(valid_formulation_design_space_data, FormulationDesignSpace) + + +def test_invalid_design_subspace_type(invalid_design_subspace_data): + """Ensures we raise proper exception when an invalid type is used.""" + with pytest.raises(ValueError): + DesignSubspace.build(invalid_design_subspace_data)