Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions natsapi/asyncapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,21 @@ def get_flat_models_from_routes(
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
replies_from_routes: Set[ModelField] = set()
requests_from_routes: Set[ModelField] = set()
messages_from_pubs: Set[ModelField] = set()
for route in routes:
if getattr(route, "include_schema", True) and isinstance(route, Request):
if route.result:
replies_from_routes.add(route.request_field)
if route.params:
replies_from_routes.add(route.reply_field)
elif getattr(route, "include_schema", True) and isinstance(route, Publish):
if route.params:
replies_from_routes.add(route.reply_field)
for pub in pubs:
messages_from_pubs.add(pub.params_field)

flat_models = get_flat_models_from_fields(
if getattr(route, "include_schema", True):
if isinstance(route, Request):
if route.result:
replies_from_routes.add(route.request_field)
if route.params:
replies_from_routes.add(route.reply_field)
elif isinstance(route, Publish):
if route.params:
replies_from_routes.add(route.reply_field)
messages_from_pubs: Set[ModelField] = {pub.params_field for pub in pubs}
return get_flat_models_from_fields(
replies_from_routes | requests_from_routes | messages_from_pubs,
known_models=set(),
)
return flat_models
Comment on lines -26 to -43
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_flat_models_from_routes refactored with the following changes:



def get_flat_response_models(r) -> List[Type[BaseModel]]:
Expand All @@ -50,15 +47,13 @@ def get_flat_response_models(r) -> List[Type[BaseModel]]:

:r Single or multiple response models
"""
if type(r) is typing._UnionGenericAlias:
return list(r.__args__)
else:
return [r]
return list(r.__args__) if type(r) is typing._UnionGenericAlias else [r]
Comment on lines -53 to +50
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_flat_response_models refactored with the following changes:



def get_asyncapi_request_operation_metadata(operation: Request) -> Dict[str, Any]:
metadata: Dict[str, Any] = {}
metadata["summary"] = operation.summary.replace("_", " ").title()
metadata: Dict[str, Any] = {
"summary": operation.summary.replace("_", " ").title()
}
Comment on lines -60 to +56
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_asyncapi_request_operation_metadata refactored with the following changes:

metadata["description"] = operation.description

if operation.tags:
Expand Down Expand Up @@ -106,8 +101,7 @@ def generate_asyncapi_publish_channel(operation: Publish, model_name_map: Dict[s


def domain_errors_schema(lower_bound: int, upper_bound: int, exceptions: List[Exception]):
schema = {}
schema["range"] = {"upper": upper_bound, "lower": lower_bound}
schema = {"range": {"upper": upper_bound, "lower": lower_bound}}
Comment on lines -109 to +104
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function domain_errors_schema refactored with the following changes:

errors = []
for exc in exceptions:
try:
Expand Down Expand Up @@ -174,8 +168,11 @@ def get_asyncapi(
servers: Optional[Dict[str, Server]] = None,
) -> Dict[str, Any]:
subjects: Dict[str, Dict[str, Any]] = {}
info = {"title": title, "version": version}
info["description"] = description if description else None
info = {
"title": title,
"version": version,
"description": description if description else None,
}
Comment on lines -177 to +175
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_asyncapi refactored with the following changes:

components: Dict[str, Dict[str, Any]] = {}

output: Dict[str, Any] = {"asyncapi": asyncapi_version, "info": info}
Expand Down Expand Up @@ -210,7 +207,7 @@ def get_asyncapi(
output["externalDocs"] = external_docs.dict() if external_docs else None
output["errors"] = domain_errors_schema(errors.lower_bound, errors.upper_bound, errors.errors) if errors else None

output["channels"] = subjects if len(subjects) > 0 else None
output["channels"] = subjects if subjects else None
output["components"] = components

return jsonable_encoder(AsyncAPI(**output), by_alias=True, exclude_none=True)
23 changes: 15 additions & 8 deletions natsapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ async def request(
"""
json_rpc_payload = JsonRPCRequest(params=params, method=method, timeout=timeout)
reply_raw = await self.nats.request(subject, json_rpc_payload.json().encode(), timeout, headers=headers)
reply = JsonRPCReply.parse_raw(reply_raw.data)
return reply
return JsonRPCReply.parse_raw(reply_raw.data)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function NatsClient.request refactored with the following changes:


async def handle_request(self, msg):
if msg.reply and msg.reply != "None":
asyncio.create_task(self._handle_request(msg), name="natsapi_" + secrets.token_hex(16))
asyncio.create_task(
self._handle_request(msg), name=f"natsapi_{secrets.token_hex(16)}"
)
else:
asyncio.create_task(self._handle_publish(msg), name="natsapi_" + secrets.token_hex(16))
asyncio.create_task(
self._handle_publish(msg), name=f"natsapi_{secrets.token_hex(16)}"
)
Comment on lines -74 to +79
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function NatsClient.handle_request refactored with the following changes:


async def _handle_publish(self, msg):
request = JsonRPCRequest.parse_raw(msg.data)
Expand Down Expand Up @@ -151,10 +154,14 @@ def _lookup_exception_handler(self, exc: Exception) -> Optional[Callable]:
The method will only return 'None' if you have a custom exception inheriting from BaseException.
But inheriting from BaseException is bad practice, and your application should crash if you do this.
"""
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None
return next(
(
self._exception_handlers[cls]
for cls in type(exc).__mro__
if cls in self._exception_handlers
),
None,
)
Comment on lines -154 to +164
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function NatsClient._lookup_exception_handler refactored with the following changes:

  • Use the built-in function next instead of a for-loop (use-next)


async def _error_cb(self, e):
logging.exception(e)
Expand Down
36 changes: 16 additions & 20 deletions natsapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,26 @@ def jsonable_encoder(
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
return [
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
return encoded_list

for item in obj
]
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder(obj)
for encoder_type, encoder in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder(obj)
Comment on lines -97 to +116
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function jsonable_encoder refactored with the following changes:


if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
Expand Down
3 changes: 1 addition & 2 deletions natsapi/exception_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

def get_validation_target(e):
err_loc_tree = [str(loc_part) for loc_part in e["loc"]]
target = ".".join(err_loc_tree)
return target
return ".".join(err_loc_tree)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_validation_target refactored with the following changes:



def handle_jsonrpc_exception(exc: JsonRPCException, request: JsonRPCRequest, subject: str) -> JsonRPCError:
Expand Down
6 changes: 4 additions & 2 deletions natsapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ def set_id(cls, id):
return id or uuid4()

@classmethod
def with_params(self, params: BaseModel):
return create_model("JsonRPC" + params.__name__, __base__=self, params=(params, ...))
def with_params(cls, params: BaseModel):
return create_model(
f"JsonRPC{params.__name__}", __base__=cls, params=(params, ...)
)
Comment on lines -63 to +66
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function JsonRPCRequest.with_params refactored with the following changes:

8 changes: 4 additions & 4 deletions natsapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(
self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject)
self.result = result
self.params = get_request_model(self.endpoint, subject, self.skip_validation)
reply_name = "Reply_" + self.operation_id
request_name = "Request_" + self.operation_id
reply_name = f"Reply_{self.operation_id}"
request_name = f"Request_{self.operation_id}"
Comment on lines -35 to +36
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Request.__init__ refactored with the following changes:

self.reply_field = create_field(name=reply_name, type_=self.params)
self.request_field = create_field(name=request_name, type_=self.result)

Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
self.summary = summary or get_summary(endpoint) or subject
self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject)
self.params = get_request_model(self.endpoint, subject, self.skip_validation)
reply_name = "Reply_" + self.operation_id
reply_name = f"Reply_{self.operation_id}"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Publish.__init__ refactored with the following changes:

self.reply_field = create_field(name=reply_name, type_=self.params)

self.tags = tags or []
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
self.tags = tags or []
self.externalDocs = externalDocs
self.params = params
self.params_field = create_field(name="Publish_" + subject, type_=self.params)
self.params_field = create_field(name=f"Publish_{subject}", type_=self.params)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Pub.__init__ refactored with the following changes:



class SubjectRouter:
Expand Down
26 changes: 10 additions & 16 deletions natsapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def get_summary(endpoint: Callable) -> str:


def generate_operation_id_for_subject(*, summary: str, subject: str) -> str:
operation_id = summary + "_" + subject
operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
return operation_id
operation_id = f"{summary}_{subject}"
return re.sub("[^0-9a-zA-Z_]", "_", operation_id)
Comment on lines -27 to +28
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function generate_operation_id_for_subject refactored with the following changes:



def create_field(
Expand Down Expand Up @@ -68,7 +67,7 @@ def get_model_definitions(
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
definitions |= m_definitions
Comment on lines -71 to +70
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_model_definitions refactored with the following changes:

try:
model_name = model_name_map[model]
except KeyError as exc:
Expand All @@ -82,7 +81,7 @@ def get_model_definitions(

def get_request_model(func: Callable, subject: str, skip_validation: bool):
parameters = collections.OrderedDict(inspect.signature(func).parameters)
name_prefix = func.__name__ if not func.__name__ == "_" else subject
name_prefix = func.__name__ if func.__name__ != "_" else subject
Comment on lines -85 to +84
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_request_model refactored with the following changes:


if skip_validation:
assert (
Expand All @@ -94,20 +93,15 @@ def get_request_model(func: Callable, subject: str, skip_validation: bool):
for i, parameter in enumerate(parameters.values()):
if i == 0:
assert parameter.name == "app", "First parameter should be named 'app'"
if parameter.annotation == Any:
continue
else:
if parameter.annotation != Any:
assert (
parameter.annotation.__name__ in valid_app_types
), f"Valid types for app are: NatsAPI, FastAPI, or Any. Got {parameter.annotation.__name__}"
continue

continue
if parameter.name in ["args", "kwargs"] and skip_validation:
continue
else:
assert parameter.annotation is not inspect._empty, f"{parameter.name} has no type"
default = ... if parameter.default is inspect._empty else parameter.default
param_fields[parameter.name] = (parameter.annotation, default)
assert parameter.annotation is not inspect._empty, f"{parameter.name} has no type"
default = ... if parameter.default is inspect._empty else parameter.default
param_fields[parameter.name] = (parameter.annotation, default)

model = create_model(f"{name_prefix}_params", **param_fields)
return model
return create_model(f"{name_prefix}_params", **param_fields)