diff --git a/natsapi/asyncapi/utils.py b/natsapi/asyncapi/utils.py index ea14f01..ea45333 100644 --- a/natsapi/asyncapi/utils.py +++ b/natsapi/asyncapi/utils.py @@ -23,24 +23,21 @@ def get_flat_models_from_routes( ) -> Set[Union[Type[BaseModel], Type[Enum]]]: replies_from_routes: Set[ModelField] = set() requests_from_routes: Set[ModelField] = set() - messages_from_pubs: Set[ModelField] = set() for route in routes: - if getattr(route, "include_schema", True) and isinstance(route, Request): - if route.result: - replies_from_routes.add(route.request_field) - if route.params: - replies_from_routes.add(route.reply_field) - elif getattr(route, "include_schema", True) and isinstance(route, Publish): - if route.params: - replies_from_routes.add(route.reply_field) - for pub in pubs: - messages_from_pubs.add(pub.params_field) - - flat_models = get_flat_models_from_fields( + if getattr(route, "include_schema", True): + if isinstance(route, Request): + if route.result: + replies_from_routes.add(route.request_field) + if route.params: + replies_from_routes.add(route.reply_field) + elif isinstance(route, Publish): + if route.params: + replies_from_routes.add(route.reply_field) + messages_from_pubs: Set[ModelField] = {pub.params_field for pub in pubs} + return get_flat_models_from_fields( replies_from_routes | requests_from_routes | messages_from_pubs, known_models=set(), ) - return flat_models def get_flat_response_models(r) -> List[Type[BaseModel]]: @@ -50,15 +47,13 @@ def get_flat_response_models(r) -> List[Type[BaseModel]]: :r Single or multiple response models """ - if type(r) is typing._UnionGenericAlias: - return list(r.__args__) - else: - return [r] + return list(r.__args__) if type(r) is typing._UnionGenericAlias else [r] def get_asyncapi_request_operation_metadata(operation: Request) -> Dict[str, Any]: - metadata: Dict[str, Any] = {} - metadata["summary"] = operation.summary.replace("_", " ").title() + metadata: Dict[str, Any] = { + "summary": operation.summary.replace("_", " ").title() + } metadata["description"] = operation.description if operation.tags: @@ -106,8 +101,7 @@ def generate_asyncapi_publish_channel(operation: Publish, model_name_map: Dict[s def domain_errors_schema(lower_bound: int, upper_bound: int, exceptions: List[Exception]): - schema = {} - schema["range"] = {"upper": upper_bound, "lower": lower_bound} + schema = {"range": {"upper": upper_bound, "lower": lower_bound}} errors = [] for exc in exceptions: try: @@ -174,8 +168,11 @@ def get_asyncapi( servers: Optional[Dict[str, Server]] = None, ) -> Dict[str, Any]: subjects: Dict[str, Dict[str, Any]] = {} - info = {"title": title, "version": version} - info["description"] = description if description else None + info = { + "title": title, + "version": version, + "description": description if description else None, + } components: Dict[str, Dict[str, Any]] = {} output: Dict[str, Any] = {"asyncapi": asyncapi_version, "info": info} @@ -210,7 +207,7 @@ def get_asyncapi( output["externalDocs"] = external_docs.dict() if external_docs else None output["errors"] = domain_errors_schema(errors.lower_bound, errors.upper_bound, errors.errors) if errors else None - output["channels"] = subjects if len(subjects) > 0 else None + output["channels"] = subjects if subjects else None output["components"] = components return jsonable_encoder(AsyncAPI(**output), by_alias=True, exclude_none=True) diff --git a/natsapi/client/client.py b/natsapi/client/client.py index 92929ca..945e00b 100644 --- a/natsapi/client/client.py +++ b/natsapi/client/client.py @@ -66,14 +66,17 @@ async def request( """ json_rpc_payload = JsonRPCRequest(params=params, method=method, timeout=timeout) reply_raw = await self.nats.request(subject, json_rpc_payload.json().encode(), timeout, headers=headers) - reply = JsonRPCReply.parse_raw(reply_raw.data) - return reply + return JsonRPCReply.parse_raw(reply_raw.data) async def handle_request(self, msg): if msg.reply and msg.reply != "None": - asyncio.create_task(self._handle_request(msg), name="natsapi_" + secrets.token_hex(16)) + asyncio.create_task( + self._handle_request(msg), name=f"natsapi_{secrets.token_hex(16)}" + ) else: - asyncio.create_task(self._handle_publish(msg), name="natsapi_" + secrets.token_hex(16)) + asyncio.create_task( + self._handle_publish(msg), name=f"natsapi_{secrets.token_hex(16)}" + ) async def _handle_publish(self, msg): request = JsonRPCRequest.parse_raw(msg.data) @@ -151,10 +154,14 @@ def _lookup_exception_handler(self, exc: Exception) -> Optional[Callable]: The method will only return 'None' if you have a custom exception inheriting from BaseException. But inheriting from BaseException is bad practice, and your application should crash if you do this. """ - for cls in type(exc).__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None + return next( + ( + self._exception_handlers[cls] + for cls in type(exc).__mro__ + if cls in self._exception_handlers + ), + None, + ) async def _error_cb(self, e): logging.exception(e) diff --git a/natsapi/encoders.py b/natsapi/encoders.py index 84bf426..6d4745f 100644 --- a/natsapi/encoders.py +++ b/natsapi/encoders.py @@ -94,30 +94,26 @@ def jsonable_encoder( encoded_dict[encoded_key] = encoded_value return encoded_dict if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) + return [ + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, ) - return encoded_list - + for item in obj + ] if custom_encoder: if type(obj) in custom_encoder: return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder(obj) + for encoder_type, encoder in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder(obj) if type(obj) in ENCODERS_BY_TYPE: return ENCODERS_BY_TYPE[type(obj)](obj) diff --git a/natsapi/exception_handlers.py b/natsapi/exception_handlers.py index ebe1bb5..ba4913b 100644 --- a/natsapi/exception_handlers.py +++ b/natsapi/exception_handlers.py @@ -8,8 +8,7 @@ def get_validation_target(e): err_loc_tree = [str(loc_part) for loc_part in e["loc"]] - target = ".".join(err_loc_tree) - return target + return ".".join(err_loc_tree) def handle_jsonrpc_exception(exc: JsonRPCException, request: JsonRPCRequest, subject: str) -> JsonRPCError: diff --git a/natsapi/models.py b/natsapi/models.py index 01b2c63..e60baf0 100644 --- a/natsapi/models.py +++ b/natsapi/models.py @@ -60,5 +60,7 @@ def set_id(cls, id): return id or uuid4() @classmethod - def with_params(self, params: BaseModel): - return create_model("JsonRPC" + params.__name__, __base__=self, params=(params, ...)) + def with_params(cls, params: BaseModel): + return create_model( + f"JsonRPC{params.__name__}", __base__=cls, params=(params, ...) + ) diff --git a/natsapi/routing.py b/natsapi/routing.py index 9ada06f..4b980d1 100644 --- a/natsapi/routing.py +++ b/natsapi/routing.py @@ -32,8 +32,8 @@ def __init__( self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject) self.result = result self.params = get_request_model(self.endpoint, subject, self.skip_validation) - reply_name = "Reply_" + self.operation_id - request_name = "Request_" + self.operation_id + reply_name = f"Reply_{self.operation_id}" + request_name = f"Request_{self.operation_id}" self.reply_field = create_field(name=reply_name, type_=self.params) self.request_field = create_field(name=request_name, type_=self.result) @@ -66,7 +66,7 @@ def __init__( self.summary = summary or get_summary(endpoint) or subject self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject) self.params = get_request_model(self.endpoint, subject, self.skip_validation) - reply_name = "Reply_" + self.operation_id + reply_name = f"Reply_{self.operation_id}" self.reply_field = create_field(name=reply_name, type_=self.params) self.tags = tags or [] @@ -115,7 +115,7 @@ def __init__( self.tags = tags or [] self.externalDocs = externalDocs self.params = params - self.params_field = create_field(name="Publish_" + subject, type_=self.params) + self.params_field = create_field(name=f"Publish_{subject}", type_=self.params) class SubjectRouter: diff --git a/natsapi/utils.py b/natsapi/utils.py index ee84d3f..0d2679e 100644 --- a/natsapi/utils.py +++ b/natsapi/utils.py @@ -24,9 +24,8 @@ def get_summary(endpoint: Callable) -> str: def generate_operation_id_for_subject(*, summary: str, subject: str) -> str: - operation_id = summary + "_" + subject - operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) - return operation_id + operation_id = f"{summary}_{subject}" + return re.sub("[^0-9a-zA-Z_]", "_", operation_id) def create_field( @@ -68,7 +67,7 @@ def get_model_definitions( m_schema, m_definitions, m_nested_models = model_process_schema( model, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) - definitions.update(m_definitions) + definitions |= m_definitions try: model_name = model_name_map[model] except KeyError as exc: @@ -82,7 +81,7 @@ def get_model_definitions( def get_request_model(func: Callable, subject: str, skip_validation: bool): parameters = collections.OrderedDict(inspect.signature(func).parameters) - name_prefix = func.__name__ if not func.__name__ == "_" else subject + name_prefix = func.__name__ if func.__name__ != "_" else subject if skip_validation: assert ( @@ -94,20 +93,15 @@ def get_request_model(func: Callable, subject: str, skip_validation: bool): for i, parameter in enumerate(parameters.values()): if i == 0: assert parameter.name == "app", "First parameter should be named 'app'" - if parameter.annotation == Any: - continue - else: + if parameter.annotation != Any: assert ( parameter.annotation.__name__ in valid_app_types ), f"Valid types for app are: NatsAPI, FastAPI, or Any. Got {parameter.annotation.__name__}" - continue - + continue if parameter.name in ["args", "kwargs"] and skip_validation: continue - else: - assert parameter.annotation is not inspect._empty, f"{parameter.name} has no type" - default = ... if parameter.default is inspect._empty else parameter.default - param_fields[parameter.name] = (parameter.annotation, default) + assert parameter.annotation is not inspect._empty, f"{parameter.name} has no type" + default = ... if parameter.default is inspect._empty else parameter.default + param_fields[parameter.name] = (parameter.annotation, default) - model = create_model(f"{name_prefix}_params", **param_fields) - return model + return create_model(f"{name_prefix}_params", **param_fields)