From 3d786df58a3952f844785f37c944cd34abb7b1f6 Mon Sep 17 00:00:00 2001 From: ivasylenko Date: Thu, 16 Oct 2025 15:55:57 +0300 Subject: [PATCH] GC-117703: sdk adjustments --- fastapi_code_generator/__main__.py | 57 +++++++-- fastapi_code_generator/parser.py | 9 +- fastapi_code_generator/visitors/imports.py | 13 +- poetry.lock | 137 ++------------------- pyproject.toml | 2 +- 5 files changed, 82 insertions(+), 136 deletions(-) diff --git a/fastapi_code_generator/__main__.py b/fastapi_code_generator/__main__.py index 4d3e661..27d5a4b 100644 --- a/fastapi_code_generator/__main__.py +++ b/fastapi_code_generator/__main__.py @@ -1,11 +1,13 @@ import re +import json from datetime import datetime, timezone from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path +from collections import defaultdict from typing import Any, Dict, List, Optional import typer -from datamodel_code_generator import DataModelType, LiteralType, PythonVersion, chdir +from datamodel_code_generator import DataModelType, LiteralType, PythonVersion, chdir, DatetimeClassType from datamodel_code_generator.format import CodeFormatter from datamodel_code_generator.imports import Import, Imports from datamodel_code_generator.model import get_data_model_types @@ -65,15 +67,38 @@ def main( python_version: PythonVersion = typer.Option( PythonVersion.PY_39.value, "--python-version", "-p" ), + capitalise_enum_members: bool = typer.Option(False, "--capitalise-enum-members"), + output_datetime_class: Optional[DatetimeClassType] = typer.Option( + DatetimeClassType.Datetime, "--output-datetime-class", + help="Specify the datetime class to use for datetime fields" + ), + allow_population_by_field_name: bool = typer.Option(False, "--allow-population-by-field-name"), + extra_template_data: str = typer.Option(None, "--extra-template-data"), + additional_imports: str = typer.Option(None, "--additional-imports"), ) -> None: input_name: str = input_file - input_text: str + input_text: str = None + + input_name = Path(input_name).expanduser().resolve() - with open(input_file, encoding=encoding) as f: - input_text = f.read() + try: + with open(input_file, encoding=encoding) as f: + input_text = f.read() + except: + pass + + if extra_template_data: + try: + with open(extra_template_data, encoding=encoding) as f: + extra_template_data = json.load(f, object_hook=lambda d: defaultdict(dict, **d)) + except Exception as exc: + print(f"could not load extra: {exc}") model_path = Path(model_file) if model_file else MODEL_PATH # pragma: no cover + if additional_imports: + additional_imports = additional_imports.split(",") + return generate_code( input_name, input_text, @@ -89,6 +114,11 @@ def main( specify_tags=specify_tags, output_model_type=output_model_type, python_version=python_version, + capitalise_enum_members=capitalise_enum_members, + output_datetime_class=output_datetime_class, + allow_population_by_field_name=allow_population_by_field_name, + extra_template_data=extra_template_data, + additional_imports=additional_imports ) @@ -101,7 +131,6 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]: return reference return None - def generate_code( input_name: str, input_text: str, @@ -117,6 +146,11 @@ def generate_code( specify_tags: Optional[str] = None, output_model_type: DataModelType = DataModelType.PydanticBaseModel, python_version: PythonVersion = PythonVersion.PY_39, + capitalise_enum_members: bool = False, + output_datetime_class: Optional[DatetimeClassType] = None, + extra_template_data: defaultdict[str, dict[str, Any]] | None = None, + allow_population_by_field_name: Optional[bool] = False, + additional_imports: Optional[list[str]] = None ) -> None: if not model_path: model_path = MODEL_PATH @@ -132,8 +166,10 @@ def generate_code( custom_visitors = [] data_model_types = get_data_model_types(output_model_type, python_version) + source = input_text or input_name + parser = OpenAPIParser( - input_text, + source=source, enum_field_as_literal=enum_field_as_literal, data_model_type=data_model_types.data_model, data_model_root_type=data_model_types.root_model, @@ -142,6 +178,13 @@ def generate_code( dump_resolve_reference_action=data_model_types.dump_resolve_reference_action, custom_template_dir=model_template_dir, target_python_version=python_version, + additional_imports=additional_imports, + base_path=Path(input_name).absolute().parent, + capitalise_enum_members=capitalise_enum_members, + output_datetime_class=output_datetime_class, + extra_template_data=extra_template_data, + allow_population_by_field_name=allow_population_by_field_name, + field_extra_keys={"union_mode"} ) with chdir(output_dir): @@ -153,7 +196,7 @@ def generate_code( modules = {output_dir / model_path.with_suffix('.py'): (models, input_name)} else: modules = { - output_dir / model_path / module_name[0]: (model.body, input_name) + output_dir / model_path / Path(*module_name): (model.body, input_name) for module_name, model in models.items() } diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index f9b0678..8515e24 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -27,6 +27,7 @@ OpenAPIScope, PythonVersion, snooper_to_methods, + DatetimeClassType, ) from datamodel_code_generator.imports import Import, Imports from datamodel_code_generator.model import DataModel, DataModelFieldBase @@ -51,7 +52,7 @@ class CachedPropertyModel(BaseModel): class Config: arbitrary_types_allowed = True - keep_untouched = (cached_property,) + ignored_types = (cached_property,) class Response(BaseModel): @@ -265,6 +266,9 @@ def __init__( custom_class_name_generator: Optional[Callable[[str], str]] = None, field_extra_keys: Optional[Set[str]] = None, field_include_all_keys: bool = False, + capitalise_enum_members: bool = False, + additional_imports: List[str] = None, + output_datetime_class: Optional[DatetimeClassType] = None, ): super().__init__( source=source, @@ -282,6 +286,7 @@ def __init__( snake_case_field=snake_case_field, strip_default_none=strip_default_none, aliases=aliases, + additional_imports=additional_imports, allow_population_by_field_name=allow_population_by_field_name, apply_default_values_for_required_fields=apply_default_values_for_required_fields, force_optional_for_required_fields=force_optional_for_required_fields, @@ -304,6 +309,8 @@ def __init__( field_extra_keys=field_extra_keys, field_include_all_keys=field_include_all_keys, openapi_scopes=[OpenAPIScope.Schemas, OpenAPIScope.Paths], + capitalise_enum_members=capitalise_enum_members, + target_datetime_class=output_datetime_class ) self.operations: Dict[str, Operation] = {} self._temporary_operation: Dict[str, Any] = {} diff --git a/fastapi_code_generator/visitors/imports.py b/fastapi_code_generator/visitors/imports.py index b1b37be..a229a32 100644 --- a/fastapi_code_generator/visitors/imports.py +++ b/fastapi_code_generator/visitors/imports.py @@ -26,9 +26,10 @@ def get_imports(parser: OpenAPIParser, model_path: Path) -> Dict[str, object]: for data_type in parser.data_types: reference = _get_most_of_reference(data_type) if reference: + reference_path = _get_path_of_reference(model_path, reference) imports.append(data_type.all_imports) imports.append( - Import.from_full_path(f'.{model_path.stem}.{reference.name}') + Import.from_full_path(f'.{reference_path}.{reference.name}') ) for from_, imports_ in parser.imports_for_fastapi.items(): imports[from_].update(imports_) @@ -37,5 +38,15 @@ def get_imports(parser: OpenAPIParser, model_path: Path) -> Dict[str, object]: imports.alias.update(operation.imports.alias) return {'imports': imports} +def _get_path_of_reference(model_path, reference): + m_path = model_path.stem + try: + yaml, *_ = reference.path.split('#') + mod, *_ = yaml.split('.') + return f"{m_path}.{'.'.join(m for m in mod.split('/'))}" + except Exception: + pass + return m_path + visit: Visitor = get_imports diff --git a/poetry.lock b/poetry.lock index f202432..618dbb6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,29 +12,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[[package]] -name = "anyio" -version = "4.4.0" -description = "High level compatibility layer for multiple asynchronous event loop implementations" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, - {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, -] - -[package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} -idna = ">=2.8" -sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} - -[package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] -trio = ["trio (>=0.23)"] - [[package]] name = "argcomplete" version = "3.4.0" @@ -97,18 +74,6 @@ d = ["aiohttp (>=3.7.4) ; sys_platform != \"win32\" or implementation_name != \" jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] -[[package]] -name = "certifi" -version = "2024.7.4" -description = "Python package for providing Mozilla's CA Bundle." -optional = false -python-versions = ">=3.6" -groups = ["main"] -files = [ - {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, - {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, -] - [[package]] name = "click" version = "8.1.7" @@ -207,21 +172,18 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "datamodel-code-generator" -version = "0.30.1" +version = "0.1.dev1273+g0f67b3cc2" description = "Datamodel Code Generator" optional = false python-versions = ">=3.9" groups = ["main"] -files = [ - {file = "datamodel_code_generator-0.30.1-py3-none-any.whl", hash = "sha256:9601dfa3da8aa8d8d54e182059f78836b1768a807d5c26df798db12d4054c8f3"}, - {file = "datamodel_code_generator-0.30.1.tar.gz", hash = "sha256:d125012face4cd1eca6c9300297a1f5775a9d5ff8fc3f68d34d0944a7beea105"}, -] +files = [] +develop = false [package.dependencies] argcomplete = ">=2.10.1,<4" black = ">=19.10b0" genson = ">=1.2.1,<2" -httpx = {version = ">=0.24.1", optional = true, markers = "extra == \"http\""} inflect = ">=4.1,<8" isort = ">=4.3.21,<7" jinja2 = ">=2.10.1,<4" @@ -238,13 +200,19 @@ http = ["httpx (>=0.24.1)"] ruff = ["ruff (>=0.9.10)"] validation = ["openapi-spec-validator (>=0.2.8,<0.7)", "prance (>=0.18.2)"] +[package.source] +type = "git" +url = "https://github.com/guardicore/datamodel-code-generator.git" +reference = "centra_contracts" +resolved_reference = "0f67b3cc2329b4449cdf9c829c32d98e52edd01e" + [[package]] name = "exceptiongroup" version = "1.2.1" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["dev"] markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, @@ -281,77 +249,6 @@ files = [ {file = "genson-1.3.0.tar.gz", hash = "sha256:e02db9ac2e3fd29e65b5286f7135762e2cd8a986537c075b06fc5f1517308e37"}, ] -[[package]] -name = "h11" -version = "0.14.0" -description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, - {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, -] - -[[package]] -name = "httpcore" -version = "1.0.5" -description = "A minimal low-level HTTP client." -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, - {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, -] - -[package.dependencies] -certifi = "*" -h11 = ">=0.13,<0.15" - -[package.extras] -asyncio = ["anyio (>=4.0,<5.0)"] -http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.26.0)"] - -[[package]] -name = "httpx" -version = "0.27.0" -description = "The next generation HTTP client." -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, -] - -[package.dependencies] -anyio = "*" -certifi = "*" -httpcore = "==1.*" -idna = "*" -sniffio = "*" - -[package.extras] -brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] -http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] - -[[package]] -name = "idna" -version = "3.7" -description = "Internationalized Domain Names in Applications (IDNA)" -optional = false -python-versions = ">=3.5" -groups = ["main"] -files = [ - {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, - {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, -] - [[package]] name = "inflect" version = "5.6.2" @@ -968,18 +865,6 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] -[[package]] -name = "sniffio" -version = "1.3.1" -description = "Sniff out which async library your code is running under" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, - {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, -] - [[package]] name = "stringcase" version = "1.2.0" @@ -1067,4 +952,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.9,<3.13" -content-hash = "9f5efcf36f6a18334b4c4c5a9d941971043875bfe8356a902afd23a50a62ce1f" +content-hash = "97e888715437294e71eb951f61d9ee035bebc17d2428aebbd4a06db94736998c" diff --git a/pyproject.toml b/pyproject.toml index 97f36d7..94fec7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ fastapi-codegen = "fastapi_code_generator.__main__:app" python = ">=3.9,<3.13" click = ">=8.0.0,<8.2.0" typer = {extras = ["all"], version = ">=0.2.1,<0.13.0"} -datamodel-code-generator = {extras = ["http"], version = "0.30.1"} +datamodel-code-generator = {git = "https://github.com/guardicore/datamodel-code-generator.git", rev = "centra_contracts"} stringcase = "^1.2.0" PySnooper = ">=0.4.1,<1.2.0" jinja2 = ">=2.11.2,<4.0.0"