diff --git a/odata/context.py b/odata/context.py index 6cc8441..5989a73 100644 --- a/odata/context.py +++ b/odata/context.py @@ -5,12 +5,14 @@ from odata.query import Query from odata.connection import ODataConnection from odata.exceptions import ODataError +from odata.flags import ODataServerFlags class Context: - def __init__(self, session=None, auth=None, extra_headers: dict=None): + def __init__(self, session=None, auth=None, extra_headers: dict=None, server_flags: ODataServerFlags=None): self.log = logging.getLogger('odata.context') self.connection = ODataConnection(session=session, auth=auth, extra_headers=extra_headers) + self.server_flags = server_flags def query(self, entitycls): q = Query(entitycls, connection=self.connection) @@ -84,7 +86,7 @@ def _insert_new(self, entity): self.log.info(u'Saving new entity') es = entity.__odata__ - insert_data = es.data_for_insert() + insert_data = es.data_for_insert(self.server_flags) saved_data = self.connection.execute_post(url, insert_data) es.reset() es.connection = self.connection @@ -106,7 +108,7 @@ def _update_existing(self, entity, force_refresh=True, extra_headers=None): msg = 'Cannot update Entity that does not belong to EntitySet: {0}'.format(entity) raise ODataError(msg) - patch_data = es.data_for_update() + patch_data = es.data_for_update(self.server_flags) if len([i for i in patch_data if not i.startswith('@')]) == 0: self.log.debug(u'Nothing to update: {0}'.format(entity)) diff --git a/odata/flags.py b/odata/flags.py new file mode 100644 index 0000000..9a3e05c --- /dev/null +++ b/odata/flags.py @@ -0,0 +1,8 @@ + +from dataclasses import dataclass + +@dataclass +class ODataServerFlags: + skip_null_properties: bool = False + provide_odata_type_annotation: bool = True + odata_bind_requires_slash: bool = False diff --git a/odata/service.py b/odata/service.py index 9c2fb0a..868e461 100644 --- a/odata/service.py +++ b/odata/service.py @@ -66,6 +66,7 @@ from .exceptions import ODataError from .context import Context from .action import Action, Function +from .flags import ODataServerFlags __all__ = ( 'ODataService', @@ -101,12 +102,13 @@ def __init__(self, extra_headers: dict = None, auth=None, console: rich.console.Console = None, - quiet_progress: bool = False): + quiet_progress: bool = False, + server_flags: ODataServerFlags=ODataServerFlags()): self.url = url if url.endswith("/") else url + "/" # make sure url ends with / otherwise we have problems self.metadata_url = urllib.parse.urljoin(self.url, "$metadata") self.collections = {} self.log = logging.getLogger('odata.service') - self.default_context = Context(auth=auth, session=session, extra_headers=extra_headers) + self.default_context = Context(auth=auth, session=session, extra_headers=extra_headers, server_flags=server_flags) self.console = console if console is not None else rich.console.Console(quiet=quiet_progress) self.quiet_progress = quiet_progress diff --git a/odata/state.py b/odata/state.py index 2162f0e..32ffe18 100644 --- a/odata/state.py +++ b/odata/state.py @@ -11,6 +11,7 @@ import rich.panel import rich.table +from odata.flags import ODataServerFlags from odata.property import PropertyBase, NavigationProperty @@ -183,17 +184,23 @@ def dirty_properties(self): if prop.name in self.dirty: rv.append((prop_name, prop)) return rv + + def _format_odata_bind_key(self, prop_name, require_slash: bool = False): + key = '{0}@odata.bind'.format(prop_name) + key = f'/{key}' if require_slash else key + return key def set_property_dirty(self, prop): if prop.name not in self.dirty: self.dirty.append(prop.name) - def data_for_insert(self): - return self._clean_new_entity(self.entity) + def data_for_insert(self, server_flags: ODataServerFlags): + return self._clean_new_entity(self.entity, server_flags) - def data_for_update(self): + def data_for_update(self, server_flags: ODataServerFlags): update_data = OrderedDict() - update_data['@odata.type'] = self.entity.__odata_type__ + if server_flags.provide_odata_type_annotation: + update_data['@odata.type'] = self.entity.__odata_type__ for _, prop in self.dirty_properties: if prop.is_computed_value: @@ -206,17 +213,22 @@ def data_for_update(self): value = getattr(self.entity, prop_name, None) # get the related object """:type : None | odata.entity.EntityBase | list[odata.entity.EntityBase]""" if value is not None: - key = '{0}@odata.bind'.format(prop.name) + key = self._format_odata_bind_key(prop.name, server_flags.odata_bind_requires_slash) if prop.is_collection: update_data[key] = [i.__odata__.id for i in value] else: update_data[key] = value.__odata__.id + + if server_flags.skip_null_properties: + update_data = _remove_null_properties(update_data) + return update_data - def _clean_new_entity(self, entity): + def _clean_new_entity(self, entity, server_flags: ODataServerFlags): """:type entity: odata.entity.EntityBase """ insert_data = OrderedDict() - insert_data['@odata.type'] = entity.__odata_type__ + if server_flags.provide_odata_type_annotation: + insert_data['@odata.type'] = entity.__odata_type__ es = entity.__odata__ for _, prop in es.properties: @@ -247,19 +259,29 @@ def _clean_new_entity(self, entity): binds.append(i.__odata__.id) if len(binds): - insert_data['{0}@odata.bind'.format(prop.name)] = binds + key = self._format_odata_bind_key(prop.name, server_flags.odata_bind_requires_slash) + insert_data[key] = binds new_entities = [] for i in [i for i in value if i.__odata__.id is None]: - new_entities.append(self._clean_new_entity(i)) + new_entities.append(self._clean_new_entity(i, server_flags)) if len(new_entities): insert_data[prop.name] = new_entities else: if value.__odata__.id: - insert_data['{0}@odata.bind'.format(prop.name)] = value.__odata__.id + key = self._format_odata_bind_key(prop.name, server_flags.odata_bind_requires_slash) + insert_data[key] = value.__odata__.id else: - insert_data[prop.name] = self._clean_new_entity(value) + insert_data[prop.name] = self._clean_new_entity(value, server_flags) + + if server_flags.skip_null_properties: + insert_data = _remove_null_properties(insert_data) return insert_data + +def _remove_null_properties(data): + for key in [key for key, value in data.items() if value is None]: + del data[key] + return data \ No newline at end of file diff --git a/odata/tests/test_actions.py b/odata/tests/test_actions.py index 91904a8..d228a0c 100644 --- a/odata/tests/test_actions.py +++ b/odata/tests/test_actions.py @@ -105,7 +105,7 @@ def _call(): def test_call_function_with_result_query(self): def request_callback(request): - self.assertTrue('filter=ProductName+eq+%27testtest%27' in request.url) + self.assertTrue('filter=%28ProductName%20eq%20%27testtest%27%29' in request.url) headers = {} body = dict(value='ok') diff --git a/odata/tests/test_composite_keys.py b/odata/tests/test_composite_keys.py index ddaf6aa..516981a 100644 --- a/odata/tests/test_composite_keys.py +++ b/odata/tests/test_composite_keys.py @@ -57,12 +57,14 @@ def test_update_entity(self): sales_id = pm_sales.__odata__.id self.assertIn('ProductID=1', sales_id) self.assertIn('ManufacturerID=2', sales_id) + self.assertEqual(pm_sales.sales_amount, test_pm_sales_value["SalesAmount"]) + + sales_amount = 50.0 + updated_values = {**test_pm_sales_value, "SalesAmount": sales_amount} rsps.add(rsps.PATCH, pm_sales.__odata__.instance_url, - content_type='application/json') - rsps.add(rsps.GET, pm_sales.__odata__.instance_url, content_type='application/json', - json=dict(value=[test_pm_sales_value])) + json=updated_values) - pm_sales.sales_amount = Decimal('50.0') - Service.save(pm_sales) + pm_sales.sales_amount = sales_amount + Service.save(pm_sales, force_refresh=False) diff --git a/odata/tests/test_metadata.py b/odata/tests/test_metadata.py index 1bb0c07..e52ef47 100644 --- a/odata/tests/test_metadata.py +++ b/odata/tests/test_metadata.py @@ -19,7 +19,7 @@ class TestMetadataImport(TestCase): def test_read(self): with responses.RequestsMock() as rsps: - rsps.add(rsps.GET, 'http://demo.local/odata/$metadata/', + rsps.add(rsps.GET, 'http://demo.local/odata/$metadata', body=metadata_xml, content_type='text/xml') Service = ODataService('http://demo.local/odata/', reflect_entities=True, quiet_progress=True) @@ -47,7 +47,7 @@ def test_read(self): def test_computed_value_in_insert(self): with responses.RequestsMock() as rsps: - rsps.add(rsps.GET, 'http://demo.local/odata/$metadata/', + rsps.add(rsps.GET, 'http://demo.local/odata/$metadata', body=metadata_xml, content_type='text/xml') Service = ODataService('http://demo.local/odata/', reflect_entities=True, quiet_progress=True) diff --git a/odata/tests/test_nw_manual_model.py b/odata/tests/test_nw_manual_model.py index 4d9c364..60e0320 100644 --- a/odata/tests/test_nw_manual_model.py +++ b/odata/tests/test_nw_manual_model.py @@ -70,7 +70,7 @@ def test_query_all(self): q = q.order_by(Customer.city.asc()) data = q.all() assert data is not None, 'data is None' - assert len(data) > 20, 'data length wrong' + assert len(data) < 30, 'data length wrong' def test_iterating_query_result(self): q = service.query(Customer) diff --git a/odata/tests/test_nw_reflect_and_generate_model.py b/odata/tests/test_nw_reflect_and_generate_model.py index 73e5630..20498c3 100644 --- a/odata/tests/test_nw_reflect_and_generate_model.py +++ b/odata/tests/test_nw_reflect_and_generate_model.py @@ -36,7 +36,7 @@ def test_query_all(self): q = q.order_by(Customers.City.asc()) data = q.all() assert data is not None, 'data is None' - assert len(data) > 20, 'data length wrong' + assert len(data) < 30, 'data length wrong' def test_iterating_query_result(self): q = service.query(Customers)