Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cosmo/clients/netbox_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ class NetboxAPIClient:
def __init__(self, url, token, interprocess_shared_cache: DictProxy):
self.url = url
self.token = token
self.session = requests.Session()
self.cache = interprocess_shared_cache

def query(self, query):
r = requests.post(
r = self.session.post(
urljoin(self.url, "/graphql/"),
json={"query": query},
headers={
Expand All @@ -33,7 +34,7 @@ def query(self, query):

def _cached_get(self, url, headers):
if url not in self.cache:
self.cache[url] = requests.get(
self.cache[url] = self.session.get(
url,
headers=headers,
)
Expand Down
54 changes: 30 additions & 24 deletions cosmo/clients/netbox_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

from cosmo.clients.netbox_client import NetboxAPIClient

fetcher: NetboxAPIClient

class ParallelQuery(ABC):

def __init__(self, client: NetboxAPIClient, **kwargs):
self.client = client
def proc_init(url, token, shared_dict):
global fetcher
fetcher = NetboxAPIClient(url, token, shared_dict)


class ParallelQuery(ABC):

def __init__(self, **kwargs):
self.data_promise = None
self.kwargs = kwargs

Expand Down Expand Up @@ -71,7 +76,7 @@ def _fetch_data(self, kwargs, pool):
"""
)

return self.client.query(query_template.substitute())["data"]
return fetcher.query(query_template.substitute())["data"]

def _merge_into(self, data: dict, query_data):

Expand Down Expand Up @@ -156,7 +161,7 @@ def _fetch_data(self, kwargs, pool):
"""
)

return self.client.query(query_template.substitute())["data"]
return fetcher.query(query_template.substitute())["data"]

def _merge_into(self, data: dict, query_data):

Expand Down Expand Up @@ -263,7 +268,7 @@ def _fetch_data(self, kwargs, pool):
"""
)

return self.client.query(query_template.substitute())["data"]
return fetcher.query(query_template.substitute())["data"]

def _merge_into(self, data: dict, query_data):
return {
Expand All @@ -276,7 +281,7 @@ class StaticRouteQuery(ParallelQuery):

def _fetch_data(self, kwargs, pool):
device_list = kwargs.get("device_list")
return self.client.query_rest(
return fetcher.query_rest(
"api/plugins/routing/staticroutes/", {"device": device_list}
)

Expand Down Expand Up @@ -307,7 +312,7 @@ class IPPoolDataQuery(ParallelQuery):

def _fetch_data(self, kwargs, pool):
device_list = kwargs.get("device_list")
return self.client.query_rest(
return fetcher.query_rest(
"api/plugins/ip-pools/ippools/", {"devices": device_list}
)

Expand Down Expand Up @@ -339,7 +344,7 @@ def _merge_into(self, data: dict, query_data):
class TobagoLineMembersDataQuery(ParallelQuery):
def _fetch_data(self, kwargs, pool):
device = kwargs.get("device")
line_members = self.client.query_rest(
line_members = fetcher.query_rest(
"api/plugins/tobago/line-members/find-by-object/",
{"content_type": "dcim.device", "object_name": device},
)
Expand Down Expand Up @@ -387,7 +392,7 @@ def _merge_into(self, data: dict, query_result):
class DeviceMACQuery(ParallelQuery):
def _fetch_data(self, kwargs, pool):
device_list = kwargs.get("device_list")
return self.client.query_rest(
return fetcher.query_rest(
"api/dcim/interfaces",
{"primary_mac_address__n": "null", "device": device_list},
)
Expand All @@ -413,7 +418,7 @@ def _merge_into(self, data: dict, query_data):
class DeviceDataQuery(ParallelQuery):

def __init__(self, *args, multiple_mac_addresses=False, **kwargs):
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self.multiple_mac_addresses = multiple_mac_addresses

def _fetch_data(self, kwargs, pool):
Expand Down Expand Up @@ -615,7 +620,7 @@ def _fetch_data(self, kwargs, pool):
device=json.dumps(device),
)

query_result = self.client.query(query)
query_result = fetcher.query(query)
return query_result["data"]

def _merge_into(self, data: dict, query_data):
Expand Down Expand Up @@ -653,38 +658,39 @@ def get_data(self, device_config):
queries.extend(
[
DeviceDataQuery(
client,
device=d,
multiple_mac_addresses=self.multiple_mac_addresses,
),
(
TobagoLineMembersDataQuery(client, device=d)
TobagoLineMembersDataQuery(device=d)
if self.feature_flags["tobago"]
else TobagoLineMemberDataDummyQuery(client, device=d)
else TobagoLineMemberDataDummyQuery(device=d)
),
]
)

queries.extend(
[
L2VPNDataQuery(client, device_list=device_list),
L2VPNDataQuery(device_list=device_list),
(
StaticRouteQuery(client, device_list=device_list)
StaticRouteQuery(device_list=device_list)
if self.feature_flags["routing"]
else StaticRouteDummyQuery(client, device_list=device_list)
else StaticRouteDummyQuery(device_list=device_list)
),
DeviceMACQuery(client, device_list=device_list),
ConnectedDevicesDataQuery(client, device_list=device_list),
LoopbackDataQuery(client, device_list=device_list),
DeviceMACQuery(device_list=device_list),
ConnectedDevicesDataQuery(device_list=device_list),
LoopbackDataQuery(device_list=device_list),
(
IPPoolDataQuery(client, device_list=device_list)
IPPoolDataQuery(device_list=device_list)
if self.feature_flags["ippools"]
else IPPoolDataDummyQuery(client, device_list=device_list)
else IPPoolDataDummyQuery(device_list=device_list)
),
]
)

with manager.Pool() as pool:
with manager.Pool(
initializer=proc_init, initargs=(self.url, self.token, manager.dict())
) as pool:
data_promises = list(map(lambda x: x.fetch_data(pool), queries))

data = dict()
Expand Down
11 changes: 9 additions & 2 deletions cosmo/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
from requests import Session

from cosmo.clients.netbox_client import NetboxAPIClient


class CommonSetup:
Expand Down Expand Up @@ -85,8 +88,12 @@ def patchPostFunc(url, json, **kwargs):

return ResponseMock(200, {"data": retVal})

getMock = mocker.patch("requests.get", side_effect=patchGetFunc)
postMock = mocker.patch("requests.post", side_effect=patchPostFunc)
getMock = mocker.patch.object(
NetboxAPIClient, "session", side_effect=patchGetFunc
)
postMock = mocker.patch.object(
NetboxAPIClient, "session", side_effect=patchPostFunc
)
return [getMock, postMock]


Expand Down
Loading