diff --git a/docs/references/index.md b/docs/references/index.md index a928966..cc14b79 100644 --- a/docs/references/index.md +++ b/docs/references/index.md @@ -28,6 +28,12 @@ Options: faster processing and better type support. Defaults to False. +--use-awaredatetime Use timezone-aware datetime objects instead of naive + datetime objects. This ensures proper handling of + timezone information in the generated models. + Only supported with Pydantic v2. + Defaults to False. + --custom-template-path TEXT Custom template path to use. Allows overriding of the built in templates. diff --git a/pyproject.toml b/pyproject.toml index 697b1bb..fb6f1d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ Changelog = "https://github.com/MarcoMuellner/openapi-python-generator/releases" [tool.poetry.dependencies] python = "^3.8" -httpx = {extras = ["all"], version = "^0.23.0"} +httpx = {extras = ["all"], version = ">=0.23.0,<1.0.0"} pydantic = "^2.10.2" orjson = "^3.9.15" Jinja2 = "^3.1.2" diff --git a/src/openapi_python_generator/__main__.py b/src/openapi_python_generator/__main__.py index f2473e5..4c264da 100644 --- a/src/openapi_python_generator/__main__.py +++ b/src/openapi_python_generator/__main__.py @@ -31,6 +31,14 @@ help="Use the orjson library to serialize the data. This is faster than the default json library and provides " "serialization of datetimes and other types that are not supported by the default json library.", ) +@click.option( + "--use-awaredatetime", + is_flag=True, + show_default=True, + default=False, + help="Use timezone-aware datetime objects instead of naive datetime objects. This ensures proper handling of " + "timezone information in the generated models.", +) @click.option( "--custom-template-path", type=str, @@ -58,6 +66,7 @@ def main( library: Optional[HTTPLibrary] = HTTPLibrary.httpx, env_token_name: Optional[str] = None, use_orjson: bool = False, + use_awaredatetime: bool = False, custom_template_path: Optional[str] = None, pydantic_version: PydanticVersion = PydanticVersion.V2, formatter: Formatter = Formatter.BLACK, @@ -69,7 +78,7 @@ def main( an OUTPUT path, where the resulting client is created. """ generate_data( - source, output, library, env_token_name, use_orjson, custom_template_path, pydantic_version, formatter + source, output, library, env_token_name, use_orjson, use_awaredatetime, custom_template_path, pydantic_version, formatter ) diff --git a/src/openapi_python_generator/generate_data.py b/src/openapi_python_generator/generate_data.py index 3bbc8d4..ecc835a 100644 --- a/src/openapi_python_generator/generate_data.py +++ b/src/openapi_python_generator/generate_data.py @@ -180,6 +180,7 @@ def generate_data( library: Optional[HTTPLibrary] = HTTPLibrary.httpx, env_token_name: Optional[str] = None, use_orjson: bool = False, + use_awaredatetime: bool = False, custom_template_path: Optional[str] = None, pydantic_version: PydanticVersion = PydanticVersion.V2, formatter: Formatter = Formatter.BLACK, @@ -195,6 +196,7 @@ def generate_data( library_config_dict[library], env_token_name, use_orjson, + use_awaredatetime, custom_template_path, pydantic_version, ) diff --git a/src/openapi_python_generator/language_converters/python/common.py b/src/openapi_python_generator/language_converters/python/common.py index e3c55c4..769c29c 100644 --- a/src/openapi_python_generator/language_converters/python/common.py +++ b/src/openapi_python_generator/language_converters/python/common.py @@ -1,10 +1,13 @@ import keyword import re from typing import Optional +from openapi_python_generator.common import PydanticVersion _use_orjson: bool = False +_pydantic_version: PydanticVersion = PydanticVersion.V2 _custom_template_path: str = None +_pydantic_use_awaredatetime: bool = False _symbol_ascii_strip_re = re.compile(r"[^A-Za-z0-9_]") @@ -16,6 +19,13 @@ def set_use_orjson(value: bool) -> None: global _use_orjson _use_orjson = value +def set_pydantic_version(value: PydanticVersion) -> None: + """ + Set the value of the global variable + :param value: value of the variable + """ + global _pydantic_version + _pydantic_version = value def get_use_orjson() -> bool: """ @@ -25,6 +35,13 @@ def get_use_orjson() -> bool: global _use_orjson return _use_orjson +def get_pydantic_version() -> PydanticVersion: + """ + Get the value of the global variable _pydantic_version. + :return: value of the variable + """ + global _pydantic_version + return _pydantic_version def set_custom_template_path(value: Optional[str]) -> None: """ @@ -44,14 +61,33 @@ def get_custom_template_path() -> Optional[str]: return _custom_template_path +def set_pydantic_use_awaredatetime(value: bool) -> None: + """ + Set whether to use AwareDateTime from pydantic instead of datetime. + :param value: value of the variable + """ + global _pydantic_use_awaredatetime + _pydantic_use_awaredatetime = value + +def get_pydantic_use_awaredatetime() -> bool: + """ + Get whether to use AwareDateTime from pydantic instead of datetime. + :return: value of the variable + """ + global _pydantic_use_awaredatetime + return _pydantic_use_awaredatetime + + def normalize_symbol(symbol: str) -> str: """ Remove invalid characters & keywords in Python symbol names :param symbol: name of the identifier :return: normalized identifier name """ - symbol = symbol.replace("-", "_") + symbol = symbol.replace("-", "_").replace(" ", "_").replace(".", "_") normalized_symbol = _symbol_ascii_strip_re.sub("", symbol) if normalized_symbol in keyword.kwlist: normalized_symbol = normalized_symbol + "_" + if len(normalized_symbol) > 0 and normalized_symbol[0].isnumeric(): + normalized_symbol = "_" + normalized_symbol return normalized_symbol diff --git a/src/openapi_python_generator/language_converters/python/generator.py b/src/openapi_python_generator/language_converters/python/generator.py index 33758da..487b55f 100644 --- a/src/openapi_python_generator/language_converters/python/generator.py +++ b/src/openapi_python_generator/language_converters/python/generator.py @@ -22,15 +22,20 @@ def generator( library_config: LibraryConfig, env_token_name: Optional[str] = None, use_orjson: bool = False, + use_awaredatetime: bool = False, custom_template_path: Optional[str] = None, pydantic_version: PydanticVersion = PydanticVersion.V2, ) -> ConversionResult: """ Generate Python code from an OpenAPI 3.0 specification. """ + if use_awaredatetime and pydantic_version != PydanticVersion.V2: + raise ValueError("Timezone-aware datetime is only supported with Pydantic v2. Please use --pydantic-version v2.") common.set_use_orjson(use_orjson) common.set_custom_template_path(custom_template_path) + common.set_pydantic_version(pydantic_version) + common.set_pydantic_use_awaredatetime(use_awaredatetime) if data.components is not None: models = generate_models(data.components, pydantic_version) diff --git a/src/openapi_python_generator/language_converters/python/jinja_config.py b/src/openapi_python_generator/language_converters/python/jinja_config.py index 25505f7..22b1f61 100644 --- a/src/openapi_python_generator/language_converters/python/jinja_config.py +++ b/src/openapi_python_generator/language_converters/python/jinja_config.py @@ -19,7 +19,7 @@ def create_jinja_env(): custom_template_path = common.get_custom_template_path() - return Environment( + environment = Environment( loader=( ChoiceLoader( [ @@ -33,3 +33,7 @@ def create_jinja_env(): autoescape=True, trim_blocks=True, ) + + environment.filters["normalize_symbol"] = common.normalize_symbol + + return environment diff --git a/src/openapi_python_generator/language_converters/python/model_generator.py b/src/openapi_python_generator/language_converters/python/model_generator.py index 94bf647..2da3e44 100644 --- a/src/openapi_python_generator/language_converters/python/model_generator.py +++ b/src/openapi_python_generator/language_converters/python/model_generator.py @@ -1,5 +1,4 @@ import itertools -import re from typing import List from typing import Optional @@ -20,6 +19,7 @@ from openapi_python_generator.models import Model from openapi_python_generator.models import Property from openapi_python_generator.models import TypeConversion +from openapi_python_generator.models import ParentModel def type_converter( # noqa: C901 @@ -118,16 +118,14 @@ def type_converter( # noqa: C901 *[i.import_types for i in conversions if i.import_types is not None] ) ) - # We only want to auto convert to datetime if orjson is used throghout the code, otherwise we can not - # serialize it to JSON. - elif schema.type == "string" and ( - schema.schema_format is None or not common.get_use_orjson() - ): - converted_type = pre_type + "str" + post_type + # With custom string format fields, in order to cast these to strict types (e.g. date, datetime, UUID) + # orjson is required for JSON serialiation. elif ( - schema.type == "string" - and schema.schema_format.startswith("uuid") - and common.get_use_orjson() + schema.type == "string" + and schema.schema_format is not None + and schema.schema_format.startswith("uuid") + # orjson and pydantic v2 both support UUID + and (common.get_use_orjson() or common.get_pydantic_version() == PydanticVersion.V2) ): if len(schema.schema_format) > 4 and schema.schema_format[4].isnumeric(): uuid_type = schema.schema_format.upper() @@ -136,9 +134,39 @@ def type_converter( # noqa: C901 else: converted_type = pre_type + "UUID" + post_type import_types = ["from uuid import UUID"] - elif schema.type == "string" and schema.schema_format == "date-time": - converted_type = pre_type + "datetime" + post_type - import_types = ["from datetime import datetime"] + elif ( + schema.type == "string" + and schema.schema_format == "date-time" + # orjson and pydantic v2 both support datetime + and (common.get_use_orjson() or common.get_pydantic_version() == PydanticVersion.V2) + ): + if common.get_pydantic_use_awaredatetime(): + converted_type = pre_type + "AwareDatetime" + post_type + import_types = ["from pydantic import AwareDatetime"] + else: + converted_type = pre_type + "datetime" + post_type + import_types = ["from datetime import datetime"] + elif ( + schema.type == "string" + and schema.schema_format == "date" + # orjson and pydantic v2 both support date + and (common.get_use_orjson() or common.get_pydantic_version() == PydanticVersion.V2) + ): + converted_type = pre_type + "date" + post_type + import_types = ["from datetime import date"] + elif ( + schema.type == "string" + and schema.schema_format == "decimal" + # orjson does not support Decimal + # See https://github.com/ijl/orjson/issues/444 + and not common.get_use_orjson() + # pydantic v2 supports Decimal + and common.get_pydantic_version() == PydanticVersion.V2 + ): + converted_type = pre_type + "Decimal" + post_type + import_types = ["from decimal import Decimal"] + elif schema.type == "string": + converted_type = pre_type + "str" + post_type elif schema.type == "integer": converted_type = pre_type + "int" + post_type elif schema.type == "number": @@ -157,7 +185,9 @@ def type_converter( # noqa: C901 elif isinstance(schema.items, Schema): original_type = "array<" + ( str(schema.items.type.value) if schema.items.type is not None else "unknown") + ">" - retVal += type_converter(schema.items, True).converted_type + items_type = type_converter(schema.items, True) + import_types = items_type.import_types + retVal += items_type.converted_type else: original_type = "array" retVal += "Any" @@ -257,6 +287,32 @@ def _generate_property_from_reference( import_type=[import_model], ) +def _generate_property( + model_name: str, + name: str, + schema_or_reference: Schema | Reference, + parent_schema: Optional[Schema] = None, +) -> Property: + if isinstance(schema_or_reference, Reference): + return _generate_property_from_reference( + model_name, name, schema_or_reference, parent_schema + ) + + return _generate_property_from_schema( + model_name, name, schema_or_reference, parent_schema + ) + +def _collect_properties_from_schema(model_name: str, parent_schema: Schema): + property_iterator = ( + parent_schema.properties.items() + if parent_schema.properties is not None + else {} + ) + for name, schema_or_reference in property_iterator: + conv_property = _generate_property( + model_name, name, schema_or_reference, parent_schema + ) + yield conv_property def generate_models(components: Components, pydantic_version: PydanticVersion = PydanticVersion.V2) -> List[Model]: """ @@ -277,11 +333,9 @@ def generate_models(components: Components, pydantic_version: PydanticVersion = for schema_name, schema_or_reference in components.schemas.items(): name = common.normalize_symbol(schema_name) if schema_or_reference.enum is not None: - value_dict = schema_or_reference.dict() - regex = re.compile(r"[\s\/=\*\+]+") + value_dict = schema_or_reference.model_dump() value_dict["enum"] = [ - re.sub(regex, "_", i) if isinstance(i, str) else f"value_{i}" - for i in value_dict["enum"] + (common.normalize_symbol(str(i)).upper(), i) for i in value_dict["enum"] ] m = Model( file_name=name, @@ -299,27 +353,39 @@ def generate_models(components: Components, pydantic_version: PydanticVersion = continue # pragma: no cover + # Enumerate properties for this model properties = [] - property_iterator = ( - schema_or_reference.properties.items() - if schema_or_reference.properties is not None - else {} - ) - for prop_name, property in property_iterator: - if isinstance(property, Reference): - conv_property = _generate_property_from_reference( - name, prop_name, property, schema_or_reference - ) - else: - conv_property = _generate_property_from_schema( - name, prop_name, property, schema_or_reference - ) + for conv_property in _collect_properties_from_schema(name, schema_or_reference): properties.append(conv_property) + # Enumerate union types that compose this model (if any) from allOf, oneOf, anyOf + parent_components = [] + components_iterator = ( + (schema_or_reference.allOf or []) + (schema_or_reference.oneOf or []) + (schema_or_reference.anyOf or []) + ) + for parent_component in components_iterator: + # For references, instead of importing properties, record inherited components + if isinstance(parent_component, Reference): + ref = parent_component.ref + parent_name = common.normalize_symbol(ref.split("/")[-1]) + parent_components.append(ParentModel( + ref = ref, + name = parent_name, + import_type = f"from .{parent_name} import {parent_name}" + )) + + # Collect inline properties + if isinstance(parent_component, Schema): + for conv_property in _collect_properties_from_schema(name, parent_component): + properties.append(conv_property) + template_name = MODELS_TEMPLATE_PYDANTIC_V2 if pydantic_version == PydanticVersion.V2 else MODELS_TEMPLATE generated_content = jinja_env.get_template(template_name).render( - schema_name=name, schema=schema_or_reference, properties=properties + schema_name=name, + schema=schema_or_reference, + properties=properties, + parent_components=parent_components ) try: @@ -333,6 +399,7 @@ def generate_models(components: Components, pydantic_version: PydanticVersion = content=generated_content, openapi_object=schema_or_reference, properties=properties, + parent_components=parent_components ) ) diff --git a/src/openapi_python_generator/language_converters/python/service_generator.py b/src/openapi_python_generator/language_converters/python/service_generator.py index 582b390..1ec8f19 100644 --- a/src/openapi_python_generator/language_converters/python/service_generator.py +++ b/src/openapi_python_generator/language_converters/python/service_generator.py @@ -26,13 +26,29 @@ HTTP_OPERATIONS = ["get", "post", "put", "delete", "options", "head", "patch", "trace"] +def _generate_body_dump_expression(data = "data") -> str: + """ + Generate expression for dumping abstract body as a dictionary. + """ + + # Use old v1 method for pydantic + if common.get_pydantic_version() == common.PydanticVersion.V1: + return f"{data}.dict()" + + # Dump model but allow orjson to serialise (fastest) + if common.get_use_orjson(): + return f"{data}.model_dump()" + + # rely on pydantic v2 to serialise (slowest, but best compatibility) + return f"{data}.model_dump(mode=\"json\")" + def generate_body_param(operation: Operation) -> Union[str, None]: if operation.requestBody is None: return None else: if isinstance(operation.requestBody, Reference): - return "data.dict()" + return _generate_body_dump_expression("data") if operation.requestBody.content is None: return None # pragma: no cover @@ -46,11 +62,12 @@ def generate_body_param(operation: Operation) -> Union[str, None]: return None # pragma: no cover if isinstance(media_type.media_type_schema, Reference): - return "data.dict()" + return _generate_body_dump_expression("data") elif isinstance(media_type.media_type_schema, Schema): schema = media_type.media_type_schema if schema.type == "array": - return "[i.dict() for i in data]" + expression = _generate_body_dump_expression("i") + return f"[{expression} for i in data]" elif schema.type == "object": return "data" else: diff --git a/src/openapi_python_generator/language_converters/python/templates/aiohttp.jinja2 b/src/openapi_python_generator/language_converters/python/templates/aiohttp.jinja2 index 6f2eb16..9a18ba2 100644 --- a/src/openapi_python_generator/language_converters/python/templates/aiohttp.jinja2 +++ b/src/openapi_python_generator/language_converters/python/templates/aiohttp.jinja2 @@ -25,9 +25,9 @@ async def {{ operation_id }}({{ params }} api_config_override : Optional[APIConf params=query_params, {% if body_param %} {% if use_orjson %} - data=orjson.dumps({{ body_param }}) + data=orjson.dumps({{ body_param | safe }}) {% else %} - json = {{ body_param }} + json = {{ body_param | safe}} {% endif %} {% endif %} ) as inital_response: diff --git a/src/openapi_python_generator/language_converters/python/templates/enum.jinja2 b/src/openapi_python_generator/language_converters/python/templates/enum.jinja2 index f88bdd6..848f159 100644 --- a/src/openapi_python_generator/language_converters/python/templates/enum.jinja2 +++ b/src/openapi_python_generator/language_converters/python/templates/enum.jinja2 @@ -1,9 +1,11 @@ from enum import Enum class {{ name }}(str, Enum): - {% for enumItem in enum %} + {% for enumItemIdent, enumItem in enum %} {% if enumItem is string %} - {{ enumItem.upper() }} = '{{ enumItem }}'{% else %} - value_{{ enumItem }} = {{ enumItem }}{% endif %} + {{ enumItemIdent }} = '{{ enumItem }}' + {% else %} + {{ enumItemIdent }} = {{ enumItem }} + {% endif %} {% endfor %} diff --git a/src/openapi_python_generator/language_converters/python/templates/httpx.jinja2 b/src/openapi_python_generator/language_converters/python/templates/httpx.jinja2 index 02c2bd1..044ddef 100644 --- a/src/openapi_python_generator/language_converters/python/templates/httpx.jinja2 +++ b/src/openapi_python_generator/language_converters/python/templates/httpx.jinja2 @@ -30,9 +30,9 @@ with httpx.Client(base_url=base_path, verify=api_config.verify) as client: params=query_params, {% if body_param %} {% if use_orjson %} - content=orjson.dumps({{ body_param }}) + content=orjson.dumps({{ body_param | safe }}) {% else %} - json = {{ body_param }} + json = {{ body_param | safe }} {% endif %} {% endif %} ) diff --git a/src/openapi_python_generator/language_converters/python/templates/models.jinja2 b/src/openapi_python_generator/language_converters/python/templates/models.jinja2 index e2a90aa..dbafa29 100644 --- a/src/openapi_python_generator/language_converters/python/templates/models.jinja2 +++ b/src/openapi_python_generator/language_converters/python/templates/models.jinja2 @@ -7,16 +7,20 @@ from pydantic import BaseModel, Field {% endfor %} {% endif %} {% endfor %} +{% for parent_component in parent_components %} +{% if parent_component.import_type is not none %} +{{ parent_component.import_type }} +{% endif %} +{% endfor %} -class {{ schema_name }}(BaseModel): +class {{ schema_name }}({% for parent_component in parent_components %}{{ parent_component.name }},{% endfor %}BaseModel): """ {{ schema.title }} model {% if schema.description != None %} {{ schema.description }} {% endif %} - """ {% for property in properties %} - {{ property.name | replace("@","") | replace("-","_") }} : {{ property.type.converted_type | safe }} = Field(alias="{{ property.name }}" {% if not property.required %}, default = {{ property.default }} {% endif %}) + {{ property.name | normalize_symbol }} : {{ property.type.converted_type | safe }} = Field(alias="{{ property.name }}" {% if not property.required %}, default = {{ property.default }} {% endif %}) {% endfor %} diff --git a/src/openapi_python_generator/language_converters/python/templates/models_pydantic_2.jinja2 b/src/openapi_python_generator/language_converters/python/templates/models_pydantic_2.jinja2 index 7d4cfbd..0e37261 100644 --- a/src/openapi_python_generator/language_converters/python/templates/models_pydantic_2.jinja2 +++ b/src/openapi_python_generator/language_converters/python/templates/models_pydantic_2.jinja2 @@ -1,5 +1,5 @@ from typing import * -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict {% for property in properties %} {% if property.type.import_types is not none %} {% for import_type in property.type.import_types %} @@ -7,20 +7,28 @@ from pydantic import BaseModel, Field {% endfor %} {% endif %} {% endfor %} +{% for parent_component in parent_components %} +{% if parent_component.import_type is not none %} +{{ parent_component.import_type }} +{% endif %} +{% endfor %} -class {{ schema_name }}(BaseModel): +class {{ schema_name }}({% for parent_component in parent_components %}{{ parent_component.name }},{% endfor %}BaseModel): """ {{ schema.title }} model {% if schema.description != None %} {{ schema.description }} {% endif %} - """ - model_config = { - "populate_by_name": True, - "validate_assignment": True - } + model_config = ConfigDict( + populate_by_name= True, + validate_assignment=True, + from_attributes=True, + {% if schema.additionalProperties %} + extra="allow", + {% endif %} + ) {% for property in properties %} - {{ property.name | replace("@","") | replace("-","_") }} : {{ property.type.converted_type | safe }} = Field(validation_alias="{{ property.name }}" {% if not property.required %}, default = {{ property.default }} {% endif %}) + {{ property.name | normalize_symbol }} : {{ property.type.converted_type | safe }} = Field(validation_alias="{{ property.name }}" {% if not property.required %}, default = {{ property.default }} {% endif %}) {% endfor %} \ No newline at end of file diff --git a/src/openapi_python_generator/language_converters/python/templates/requests.jinja2 b/src/openapi_python_generator/language_converters/python/templates/requests.jinja2 index 3aaa5de..4e6e3cd 100644 --- a/src/openapi_python_generator/language_converters/python/templates/requests.jinja2 +++ b/src/openapi_python_generator/language_converters/python/templates/requests.jinja2 @@ -25,9 +25,9 @@ def {{ operation_id }}({{ params }} api_config_override : Optional[APIConfig] = verify=api_config.verify, {% if body_param %} {% if use_orjson %} - content=orjson.dumps({{ body_param }}) + content=orjson.dumps({{ body_param | safe }}) {% else %} - json = {{ body_param }} + json = {{ body_param | safe }} {% endif %} {% endif %} ) diff --git a/src/openapi_python_generator/models.py b/src/openapi_python_generator/models.py index 60e0eb5..cf2ad1f 100644 --- a/src/openapi_python_generator/models.py +++ b/src/openapi_python_generator/models.py @@ -51,11 +51,18 @@ class Property(BaseModel): import_type: Optional[List[str]] = None +class ParentModel(BaseModel): + ref: str + name: str + import_type: Optional[str] = None + + class Model(BaseModel): file_name: str content: str openapi_object: Schema properties: List[Property] = [] + parent_components: List[ParentModel] = [] class Service(BaseModel): diff --git a/tests/build_test_api/api.py b/tests/build_test_api/api.py index 4c86648..d1d2515 100644 --- a/tests/build_test_api/api.py +++ b/tests/build_test_api/api.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, date from typing import List from typing import Optional @@ -20,17 +20,19 @@ class User(BaseModel): username: str email: str password: str - is_active: Optional[bool] + is_active: Optional[bool] = None + created_at: Optional[datetime] = None + birthdate: Optional[date] = None class Team(BaseModel): id: int name: str description: str - is_active: Optional[bool] - created_at: Optional[datetime] - updated_at: Optional[datetime] - users: Optional[List[User]] + is_active: Optional[bool] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + users: Optional[List[User]] = None @app.get("/", response_model=RootResponse, tags=["general"]) diff --git a/tests/conftest.py b/tests/conftest.py index 5ff4508..10f9383 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Dict from typing import Generator +from openapi_python_generator.language_converters.python import common import pytest from openapi_pydantic.v3.v3_0 import OpenAPI @@ -32,3 +33,52 @@ def model_data_with_cleanup_fixture(model_data) -> OpenAPI: # type: ignore if test_result_path.exists(): # delete folder and all subfolders shutil.rmtree(test_result_path) + + +@pytest.fixture +def with_orjson_enabled(): + """ + Fixture to enable orjson for the duration of the test + """ + orjson_usage = common.get_use_orjson() + common.set_use_orjson(True) + try: + yield + finally: + common.set_use_orjson(orjson_usage) + +@pytest.fixture +def with_orjson_disabled(): + """ + Fixture to enable orjson for the duration of the test + """ + orjson_usage = common.get_use_orjson() + common.set_use_orjson(False) + try: + yield + finally: + common.set_use_orjson(orjson_usage) + +@pytest.fixture +def with_pydantic_v1(): + """ + Fixture to set pydantic to v1 for the duration of the test + """ + pydantic_version = common.get_pydantic_version() + common.set_pydantic_version(common.PydanticVersion.V1) + try: + yield + finally: + common.set_pydantic_version(pydantic_version) + +@pytest.fixture +def with_pydantic_v2(): + """ + Fixture to set pydantic to v2 for the duration of the test + """ + pydantic_version = common.get_pydantic_version() + common.set_pydantic_version(common.PydanticVersion.V2) + try: + yield + finally: + common.set_pydantic_version(pydantic_version) diff --git a/tests/regression/test_issue_30_87.py b/tests/regression/test_issue_30_87.py new file mode 100644 index 0000000..7456f2f --- /dev/null +++ b/tests/regression/test_issue_30_87.py @@ -0,0 +1,24 @@ +import pytest + +from openapi_python_generator.common import HTTPLibrary +from openapi_python_generator.generate_data import get_open_api +from openapi_python_generator.parsers import generate_code_3_1 +from tests.conftest import test_data_folder + + +@pytest.mark.parametrize( + "library", + [HTTPLibrary.httpx, HTTPLibrary.aiohttp, HTTPLibrary.requests], +) +def test_issue_30_87(library) -> None: + """ + https://github.com/MarcoMuellner/openapi-python-generator/issues/30 + https://github.com/MarcoMuellner/openapi-python-generator/issues/87 + """ + openapi_obj, version = get_open_api(str(test_data_folder / "issue_30_87.json")) + result = generate_code_3_1(openapi_obj, library) # type: ignore + + expected_model = [m for m in result.models if m.openapi_object.title == "UserType"][ + 0 + ] + assert "ADMIN_USER = 'admin-user'" in expected_model.content diff --git a/tests/regression/test_issue_55.py b/tests/regression/test_issue_55.py new file mode 100644 index 0000000..572d610 --- /dev/null +++ b/tests/regression/test_issue_55.py @@ -0,0 +1,23 @@ +import pytest + +from openapi_python_generator.common import HTTPLibrary +from openapi_python_generator.generate_data import get_open_api +from openapi_python_generator.parsers import generate_code_3_1 +from tests.conftest import test_data_folder + + +@pytest.mark.parametrize( + "library", + [HTTPLibrary.httpx, HTTPLibrary.aiohttp, HTTPLibrary.requests], +) +def test_issue_55(library) -> None: + """ + https://github.com/MarcoMuellner/openapi-python-generator/issues/55 + """ + openapi_obj, version = get_open_api(str(test_data_folder / "issue_55.json")) + result = generate_code_3_1(openapi_obj, library) # type: ignore + + expected_model = [m for m in result.models if m.openapi_object.title == "UserType"][ + 0 + ] + assert "ADMIN_USER = 'admin user'" in expected_model.content diff --git a/tests/test_data/issue_30_87.json b/tests/test_data/issue_30_87.json new file mode 100644 index 0000000..4f4f15a --- /dev/null +++ b/tests/test_data/issue_30_87.json @@ -0,0 +1,70 @@ +{ + "openapi": "3.0.2", + "info": { + "title": "Title", + "version": "1.0" + }, + "paths": { + "/users": { + "get": { + "summary": "Get users", + "description": "Returns a list of users.", + "operationId": "users_get", + "parameters": [ + { + "name": "type", + "in": "query", + "required": true, + "schema": { + "$ref": "#/components/schemas/UserType" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "UserType": { + "title": "UserType", + "description": "An enumeration.", + "enum": ["admin-user", "regular-user"] + }, + "User": { + "title": "User", + "description": "A user.", + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + }, + "name": { + "type": "string" + }, + "type": { + "$ref": "#/components/schemas/UserType" + }, + "30d_active": { + "type": "boolean" + } + } + } + } + } +} diff --git a/tests/test_data/issue_55.json b/tests/test_data/issue_55.json new file mode 100644 index 0000000..3de3259 --- /dev/null +++ b/tests/test_data/issue_55.json @@ -0,0 +1,70 @@ +{ + "openapi": "3.0.2", + "info": { + "title": "Title", + "version": "1.0" + }, + "paths": { + "/users": { + "get": { + "summary": "Get users", + "description": "Returns a list of users.", + "operationId": "users_get", + "parameters": [ + { + "name": "type", + "in": "query", + "required": true, + "schema": { + "$ref": "#/components/schemas/UserType" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "UserType": { + "title": "UserType", + "description": "An enumeration.", + "enum": ["admin user", "regular user"] + }, + "User": { + "title": "User", + "description": "A user.", + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + }, + "name": { + "type": "string" + }, + "type": { + "$ref": "#/components/schemas/UserType" + }, + "30d_active": { + "type": "boolean" + } + } + } + } + } +} diff --git a/tests/test_data/test_api.json b/tests/test_data/test_api.json index 85bb7ce..03c0c16 100644 --- a/tests/test_data/test_api.json +++ b/tests/test_data/test_api.json @@ -527,6 +527,11 @@ "title": "Created At", "type": "string", "format": "date-time" + }, + "birthdate": { + "title": "Birthdate", + "type": "string", + "format": "date" } } }, @@ -573,6 +578,27 @@ } } }, + "Admin": { + "title": "Admin", + "allOf": [ + { + "$ref": "#/components/schemas/User" + }, + { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "Role name for this admin user" + }, + "group": { + "type": "integer", + "description": "Admin group ID" + } + } + } + ] + }, "EnumComponent": { "title": "EnumComponent", "enum": [ diff --git a/tests/test_generated_code.py b/tests/test_generated_code.py index 83c276b..1a55b78 100644 --- a/tests/test_generated_code.py +++ b/tests/test_generated_code.py @@ -64,7 +64,7 @@ def test_set_auth_token(): ], ) @respx.mock -def test_generate_code(model_data_with_cleanup, library, use_orjson, custom_ip): +def test_generate_code(model_data_with_cleanup, library, use_orjson, custom_ip, with_pydantic_v2): generate_data(test_data_path, test_result_path, library, use_orjson=use_orjson) result = generator(model_data_with_cleanup, library_config_dict[library]) @@ -356,8 +356,8 @@ def test_generate_code(model_data_with_cleanup, library, use_orjson, custom_ip): name="team1", description="team1", is_active=True, - created_at="", - updated_at="", + created_at=None, + updated_at=None, ) exec_code_base = f"from .test_result.services.general_service import *\nfrom datetime import datetime\nresp_result = create_team_teams_post(Team(**{data}), passed_api_config)" diff --git a/tests/test_model_generator.py b/tests/test_model_generator.py index d8917e3..66e3652 100644 --- a/tests/test_model_generator.py +++ b/tests/test_model_generator.py @@ -43,6 +43,14 @@ Schema(type=DataType.STRING, schema_format="date-time"), TypeConversion(original_type="string", converted_type="str"), ), + ( + Schema(type=DataType.STRING, schema_format="date"), + TypeConversion(original_type="string", converted_type="str"), + ), + ( + Schema(type=DataType.STRING, schema_format="decimal"), + TypeConversion(original_type="string", converted_type="str"), + ), ( Schema(type=DataType.OBJECT), TypeConversion(original_type="object", converted_type="Dict[str, Any]"), @@ -89,7 +97,142 @@ ), ], ) -def test_type_converter_simple(test_openapi_types, expected_python_types): +def test_type_converter_pydanticv1(test_openapi_types, expected_python_types, with_orjson_disabled, with_pydantic_v1): + """ + Test base case with pydantic v1 and orjson disabled + """ + assert type_converter(test_openapi_types, True) == expected_python_types + + if test_openapi_types.type == "array" and isinstance( + test_openapi_types.items, Reference + ): + expected_type = expected_python_types.converted_type.split("[")[-1].split("]")[ + 0 + ] + + assert ( + type_converter(test_openapi_types, False).converted_type + == "Optional[List[Optional[" + expected_type + "]]]" + ) + else: + assert ( + type_converter(test_openapi_types, False).converted_type + == "Optional[" + expected_python_types.converted_type + "]" + ) + +@pytest.mark.parametrize( + "test_openapi_types,expected_python_types", + [ + ( + Schema(type=DataType.STRING), + TypeConversion(original_type="string", converted_type="str"), + ), + ( + Schema(type=DataType.INTEGER), + TypeConversion(original_type="integer", converted_type="int"), + ), + ( + Schema(type=DataType.NUMBER), + TypeConversion(original_type="number", converted_type="float"), + ), + ( + Schema(type=DataType.BOOLEAN), + TypeConversion(original_type="boolean", converted_type="bool"), + ), + ( + Schema(type=DataType.STRING, schema_format="date-time"), + TypeConversion( + original_type="string", + converted_type="datetime", + import_types=["from datetime import datetime"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="date"), + TypeConversion( + original_type="string", + converted_type="date", + import_types=["from datetime import date"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="decimal"), + TypeConversion( + original_type="string", + converted_type="Decimal", + import_types=["from decimal import Decimal"], + ), + ), + ( + Schema(type=DataType.OBJECT), + TypeConversion(original_type="object", converted_type="Dict[str, Any]"), + ), + ( + Schema(type=DataType.ARRAY), + TypeConversion(original_type="array", converted_type="List[Any]"), + ), + ( + Schema(type=DataType.ARRAY, items=Schema(type=DataType.STRING)), + TypeConversion(original_type="array", converted_type="List[str]"), + ), + ( + Schema(type=DataType.ARRAY, items=Reference(ref="#/components/schemas/test_name")), + TypeConversion( + original_type="array<#/components/schemas/test_name>", + converted_type="List[test_name]", + import_types=["from .test_name import test_name"], + ), + ), + ( + Schema(type=None), + TypeConversion(original_type="object", converted_type="Any"), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid"), + TypeConversion( + original_type="string", + converted_type="UUID", + import_types=["from uuid import UUID"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid1"), + TypeConversion( + original_type="string", + converted_type="UUID1", + import_types=["from pydantic import UUID1"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid3"), + TypeConversion( + original_type="string", + converted_type="UUID3", + import_types=["from pydantic import UUID3"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid4"), + TypeConversion( + original_type="string", + converted_type="UUID4", + import_types=["from pydantic import UUID4"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid5"), + TypeConversion( + original_type="string", + converted_type="UUID5", + import_types=["from pydantic import UUID5"], + ), + ), + ], +) +def test_type_converter_pydanticv2(test_openapi_types, expected_python_types, with_orjson_disabled, with_pydantic_v2): + """ + Test base case with pydantic v2 and orjson disabled + """ assert type_converter(test_openapi_types, True) == expected_python_types if test_openapi_types.type == "array" and isinstance( @@ -137,6 +280,25 @@ def test_type_converter_simple(test_openapi_types, expected_python_types): import_types=["from datetime import datetime"], ), ), + ( + Schema(type=DataType.STRING, schema_format="date"), + TypeConversion( + original_type="string", + converted_type="date", + import_types=["from datetime import date"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="decimal"), + TypeConversion(original_type="string", converted_type="str"), + ), + ( + Schema(type=DataType.STRING, schema_format="email"), + TypeConversion( + original_type="string", + converted_type="str", + ), + ), ( Schema(type=DataType.OBJECT), TypeConversion(original_type="object", converted_type="Dict[str, Any]"), @@ -203,9 +365,10 @@ def test_type_converter_simple(test_openapi_types, expected_python_types): ), ], ) -def test_type_converter_simple_orjson(test_openapi_types, expected_python_types): - orjson_usage = common.get_use_orjson() - common.set_use_orjson(True) +def test_type_converter_orjson_pydanticv1(test_openapi_types, expected_python_types, with_orjson_enabled, with_pydantic_v1): + """ + Test type conversion with pydantic v1 and orjson enabled + """ assert type_converter(test_openapi_types, True) == expected_python_types if test_openapi_types.type == "array" and isinstance( test_openapi_types.items, Reference @@ -223,8 +386,137 @@ def test_type_converter_simple_orjson(test_openapi_types, expected_python_types) type_converter(test_openapi_types, False).converted_type == "Optional[" + expected_python_types.converted_type + "]" ) - common.set_use_orjson(orjson_usage) +@pytest.mark.parametrize( + "test_openapi_types,expected_python_types", + [ + ( + Schema(type=DataType.STRING), + TypeConversion(original_type="string", converted_type="str"), + ), + ( + Schema(type=DataType.INTEGER), + TypeConversion(original_type="integer", converted_type="int"), + ), + ( + Schema(type=DataType.NUMBER), + TypeConversion(original_type="number", converted_type="float"), + ), + ( + Schema(type=DataType.BOOLEAN), + TypeConversion(original_type="boolean", converted_type="bool"), + ), + ( + Schema(type=DataType.STRING, schema_format="date-time"), + TypeConversion( + original_type="string", + converted_type="datetime", + import_types=["from datetime import datetime"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="date"), + TypeConversion( + original_type="string", + converted_type="date", + import_types=["from datetime import date"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="decimal"), + TypeConversion(original_type="string", converted_type="str"), + ), + ( + Schema(type=DataType.STRING, schema_format="email"), + TypeConversion(original_type="string", converted_type="str"), + ), + ( + Schema(type=DataType.OBJECT), + TypeConversion(original_type="object", converted_type="Dict[str, Any]"), + ), + ( + Schema(type=DataType.ARRAY), + TypeConversion(original_type="array", converted_type="List[Any]"), + ), + ( + Schema(type=DataType.ARRAY, items=Schema(type=DataType.STRING)), + TypeConversion(original_type="array", converted_type="List[str]"), + ), + ( + Schema(type=DataType.ARRAY, items=Reference(ref="#/components/schemas/test_name")), + TypeConversion( + original_type="array<#/components/schemas/test_name>", + converted_type="List[test_name]", + import_types=["from .test_name import test_name"], + ), + ), + ( + Schema(type=None), + TypeConversion(original_type="object", converted_type="Any"), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid"), + TypeConversion( + original_type="string", + converted_type="UUID", + import_types=["from uuid import UUID"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid1"), + TypeConversion( + original_type="string", + converted_type="UUID1", + import_types=["from pydantic import UUID1"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid3"), + TypeConversion( + original_type="string", + converted_type="UUID3", + import_types=["from pydantic import UUID3"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid4"), + TypeConversion( + original_type="string", + converted_type="UUID4", + import_types=["from pydantic import UUID4"], + ), + ), + ( + Schema(type=DataType.STRING, schema_format="uuid5"), + TypeConversion( + original_type="string", + converted_type="UUID5", + import_types=["from pydantic import UUID5"], + ), + ), + ], +) +def test_type_converter_orjson_pydanticv2(test_openapi_types, expected_python_types, with_orjson_enabled, with_pydantic_v2): + """ + Test type conversion with pydantic v2 and orjson enabled + """ + assert type_converter(test_openapi_types, True) == expected_python_types + if test_openapi_types.type == "array" and isinstance( + test_openapi_types.items, Reference + ): + expected_type = expected_python_types.converted_type.split("[")[-1].split("]")[ + 0 + ] + + assert ( + type_converter(test_openapi_types, False).converted_type + == "Optional[List[Optional[" + expected_type + "]]]" + ) + else: + assert ( + type_converter(test_openapi_types, False).converted_type + == "Optional[" + expected_python_types.converted_type + "]" + ) def test_type_converter_all_of_reference(): schema = Schema( diff --git a/tests/test_service_generator.py b/tests/test_service_generator.py index 1936c7d..27aaf5a 100644 --- a/tests/test_service_generator.py +++ b/tests/test_service_generator.py @@ -68,10 +68,105 @@ (Operation(responses=default_responses, requestBody=None), None), ], ) -def test_generate_body_param(test_openapi_operation, expected_result): +def test_generate_body_param_pydanticv1(test_openapi_operation, expected_result, with_orjson_disabled, with_pydantic_v1): assert generate_body_param(test_openapi_operation) == expected_result +@pytest.mark.parametrize( + "test_openapi_operation, expected_result", + [ + ( + Operation( + responses=default_responses, + requestBody=RequestBody( + content={ + "application/json": MediaType( + media_type_schema=Reference( + ref="#/components/schemas/TestModel" + ) + ) + } + ) + ), + "data.model_dump(mode=\"json\")", + ), + ( + Operation( + responses=default_responses, + requestBody=Reference(ref="#/components/schemas/TestModel") + ), + "data.model_dump(mode=\"json\")", + ), + ( + Operation( + responses=default_responses, + requestBody=RequestBody( + content={ + "application/json": MediaType( + media_type_schema=Schema( + type=DataType.ARRAY, + items=Reference(ref="#/components/schemas/TestModel"), + ) + ) + } + ) + ), + "[i.model_dump(mode=\"json\") for i in data]", + ), + (Operation(responses=default_responses, requestBody=None), None), + ], +) +def test_generate_body_param_pydanticv2(test_openapi_operation, expected_result, with_orjson_disabled, with_pydantic_v2): + assert generate_body_param(test_openapi_operation) == expected_result + + +@pytest.mark.parametrize( + "test_openapi_operation, expected_result", + [ + ( + Operation( + responses=default_responses, + requestBody=RequestBody( + content={ + "application/json": MediaType( + media_type_schema=Reference( + ref="#/components/schemas/TestModel" + ) + ) + } + ) + ), + "data.model_dump()", + ), + ( + Operation( + responses=default_responses, + requestBody=Reference(ref="#/components/schemas/TestModel") + ), + "data.model_dump()", + ), + ( + Operation( + responses=default_responses, + requestBody=RequestBody( + content={ + "application/json": MediaType( + media_type_schema=Schema( + type=DataType.ARRAY, + items=Reference(ref="#/components/schemas/TestModel"), + ) + ) + } + ) + ), + "[i.model_dump() for i in data]", + ), + (Operation(responses=default_responses, requestBody=None), None), + ], +) +def test_generate_body_param_orjson_pydanticv2(test_openapi_operation, expected_result, with_orjson_enabled, with_pydantic_v2): + assert generate_body_param(test_openapi_operation) == expected_result + @pytest.mark.parametrize( "test_openapi_operation, expected_result", [