Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions odata/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from odata.query import Query
from odata.connection import ODataConnection
from odata.exceptions import ODataError
from odata.flags import ODataServerFlags


class Context:
def __init__(self, session=None, auth=None, extra_headers: dict=None):
def __init__(self, session=None, auth=None, extra_headers: dict=None, server_flags: ODataServerFlags=None):
self.log = logging.getLogger('odata.context')
self.connection = ODataConnection(session=session, auth=auth, extra_headers=extra_headers)
self.server_flags = server_flags

def query(self, entitycls):
q = Query(entitycls, connection=self.connection)
Expand Down Expand Up @@ -84,7 +86,7 @@ def _insert_new(self, entity):
self.log.info(u'Saving new entity')

es = entity.__odata__
insert_data = es.data_for_insert()
insert_data = es.data_for_insert(self.server_flags)
saved_data = self.connection.execute_post(url, insert_data)
es.reset()
es.connection = self.connection
Expand All @@ -106,7 +108,7 @@ def _update_existing(self, entity, force_refresh=True, extra_headers=None):
msg = 'Cannot update Entity that does not belong to EntitySet: {0}'.format(entity)
raise ODataError(msg)

patch_data = es.data_for_update()
patch_data = es.data_for_update(self.server_flags)

if len([i for i in patch_data if not i.startswith('@')]) == 0:
self.log.debug(u'Nothing to update: {0}'.format(entity))
Expand Down
8 changes: 8 additions & 0 deletions odata/flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

from dataclasses import dataclass

@dataclass
class ODataServerFlags:
skip_null_properties: bool = False
provide_odata_type_annotation: bool = True
odata_bind_requires_slash: bool = False
6 changes: 4 additions & 2 deletions odata/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from .exceptions import ODataError
from .context import Context
from .action import Action, Function
from .flags import ODataServerFlags

__all__ = (
'ODataService',
Expand Down Expand Up @@ -101,12 +102,13 @@ def __init__(self,
extra_headers: dict = None,
auth=None,
console: rich.console.Console = None,
quiet_progress: bool = False):
quiet_progress: bool = False,
server_flags: ODataServerFlags=ODataServerFlags()):
self.url = url if url.endswith("/") else url + "/" # make sure url ends with / otherwise we have problems
self.metadata_url = urllib.parse.urljoin(self.url, "$metadata")
self.collections = {}
self.log = logging.getLogger('odata.service')
self.default_context = Context(auth=auth, session=session, extra_headers=extra_headers)
self.default_context = Context(auth=auth, session=session, extra_headers=extra_headers, server_flags=server_flags)
self.console = console if console is not None else rich.console.Console(quiet=quiet_progress)
self.quiet_progress = quiet_progress

Expand Down
44 changes: 33 additions & 11 deletions odata/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import rich.panel
import rich.table

from odata.flags import ODataServerFlags
from odata.property import PropertyBase, NavigationProperty


Expand Down Expand Up @@ -183,17 +184,23 @@ def dirty_properties(self):
if prop.name in self.dirty:
rv.append((prop_name, prop))
return rv

def _format_odata_bind_key(self, prop_name, require_slash: bool = False):
key = '{0}@odata.bind'.format(prop_name)
key = f'/{key}' if require_slash else key
return key

def set_property_dirty(self, prop):
if prop.name not in self.dirty:
self.dirty.append(prop.name)

def data_for_insert(self):
return self._clean_new_entity(self.entity)
def data_for_insert(self, server_flags: ODataServerFlags):
return self._clean_new_entity(self.entity, server_flags)

def data_for_update(self):
def data_for_update(self, server_flags: ODataServerFlags):
update_data = OrderedDict()
update_data['@odata.type'] = self.entity.__odata_type__
if server_flags.provide_odata_type_annotation:
update_data['@odata.type'] = self.entity.__odata_type__

for _, prop in self.dirty_properties:
if prop.is_computed_value:
Expand All @@ -206,17 +213,22 @@ def data_for_update(self):
value = getattr(self.entity, prop_name, None) # get the related object
""":type : None | odata.entity.EntityBase | list[odata.entity.EntityBase]"""
if value is not None:
key = '{0}@odata.bind'.format(prop.name)
key = self._format_odata_bind_key(prop.name, server_flags.odata_bind_requires_slash)
if prop.is_collection:
update_data[key] = [i.__odata__.id for i in value]
else:
update_data[key] = value.__odata__.id

if server_flags.skip_null_properties:
update_data = _remove_null_properties(update_data)

return update_data

def _clean_new_entity(self, entity):
def _clean_new_entity(self, entity, server_flags: ODataServerFlags):
""":type entity: odata.entity.EntityBase """
insert_data = OrderedDict()
insert_data['@odata.type'] = entity.__odata_type__
if server_flags.provide_odata_type_annotation:
insert_data['@odata.type'] = entity.__odata_type__

es = entity.__odata__
for _, prop in es.properties:
Expand Down Expand Up @@ -247,19 +259,29 @@ def _clean_new_entity(self, entity):
binds.append(i.__odata__.id)

if len(binds):
insert_data['{0}@odata.bind'.format(prop.name)] = binds
key = self._format_odata_bind_key(prop.name, server_flags.odata_bind_requires_slash)
insert_data[key] = binds

new_entities = []
for i in [i for i in value if i.__odata__.id is None]:
new_entities.append(self._clean_new_entity(i))
new_entities.append(self._clean_new_entity(i, server_flags))

if len(new_entities):
insert_data[prop.name] = new_entities

else:
if value.__odata__.id:
insert_data['{0}@odata.bind'.format(prop.name)] = value.__odata__.id
key = self._format_odata_bind_key(prop.name, server_flags.odata_bind_requires_slash)
insert_data[key] = value.__odata__.id
else:
insert_data[prop.name] = self._clean_new_entity(value)
insert_data[prop.name] = self._clean_new_entity(value, server_flags)

if server_flags.skip_null_properties:
insert_data = _remove_null_properties(insert_data)

return insert_data

def _remove_null_properties(data):
for key in [key for key, value in data.items() if value is None]:
del data[key]
return data
2 changes: 1 addition & 1 deletion odata/tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _call():

def test_call_function_with_result_query(self):
def request_callback(request):
self.assertTrue('filter=ProductName+eq+%27testtest%27' in request.url)
self.assertTrue('filter=%28ProductName%20eq%20%27testtest%27%29' in request.url)

headers = {}
body = dict(value='ok')
Expand Down
12 changes: 7 additions & 5 deletions odata/tests/test_composite_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ def test_update_entity(self):
sales_id = pm_sales.__odata__.id
self.assertIn('ProductID=1', sales_id)
self.assertIn('ManufacturerID=2', sales_id)
self.assertEqual(pm_sales.sales_amount, test_pm_sales_value["SalesAmount"])

sales_amount = 50.0
updated_values = {**test_pm_sales_value, "SalesAmount": sales_amount}

rsps.add(rsps.PATCH, pm_sales.__odata__.instance_url,
content_type='application/json')
rsps.add(rsps.GET, pm_sales.__odata__.instance_url,
content_type='application/json',
json=dict(value=[test_pm_sales_value]))
json=updated_values)

pm_sales.sales_amount = Decimal('50.0')
Service.save(pm_sales)
pm_sales.sales_amount = sales_amount
Service.save(pm_sales, force_refresh=False)
4 changes: 2 additions & 2 deletions odata/tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestMetadataImport(TestCase):

def test_read(self):
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, 'http://demo.local/odata/$metadata/',
rsps.add(rsps.GET, 'http://demo.local/odata/$metadata',
body=metadata_xml, content_type='text/xml')
Service = ODataService('http://demo.local/odata/', reflect_entities=True, quiet_progress=True)

Expand Down Expand Up @@ -47,7 +47,7 @@ def test_read(self):

def test_computed_value_in_insert(self):
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, 'http://demo.local/odata/$metadata/',
rsps.add(rsps.GET, 'http://demo.local/odata/$metadata',
body=metadata_xml, content_type='text/xml')
Service = ODataService('http://demo.local/odata/', reflect_entities=True, quiet_progress=True)

Expand Down
2 changes: 1 addition & 1 deletion odata/tests/test_nw_manual_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_query_all(self):
q = q.order_by(Customer.city.asc())
data = q.all()
assert data is not None, 'data is None'
assert len(data) > 20, 'data length wrong'
assert len(data) < 30, 'data length wrong'

def test_iterating_query_result(self):
q = service.query(Customer)
Expand Down
2 changes: 1 addition & 1 deletion odata/tests/test_nw_reflect_and_generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_query_all(self):
q = q.order_by(Customers.City.asc())
data = q.all()
assert data is not None, 'data is None'
assert len(data) > 20, 'data length wrong'
assert len(data) < 30, 'data length wrong'

def test_iterating_query_result(self):
q = service.query(Customers)
Expand Down