Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/test_dependency_wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ async def get_async_wrapped_gen_dependency(
@pytest.mark.parametrize(
"route",
[
"/wrapped-dependency",
"/wrapped-gen-dependency",
"/async-wrapped-dependency",
"/async-wrapped-gen-dependency",
"/wrapped-dependency/",
"/wrapped-gen-dependency/",
"/async-wrapped-dependency/",
"/async-wrapped-gen-dependency/",
],
)
def test_class_dependency(route):
Expand Down
122 changes: 122 additions & 0 deletions tests/test_functools_wraps_issue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Tests for additional functools.wraps scenarios beyond the combined partial cases.

This test file adds coverage for edge cases with multiple wrapping layers,
async callable classes, and various decorator orderings.
"""
from functools import wraps, partial
from typing import AsyncGenerator

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient


app = FastAPI()


# Sync wrapper decorator that just passes through
def sync_wrapper(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


# Async wrapper decorator
def async_wrapper(func):
@wraps(func)
async def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
# Await if result is a coroutine
if hasattr(result, '__await__'):
result = await result
return result

return wrapper


# Test 1: Multiple wrapping layers - async wrapper on sync function
@async_wrapper
def multi_wrapped() -> str:
return "multi"


@app.get("/multi-wrapped/")
async def endpoint_multi(value: str = Depends(multi_wrapped)):
return value


# Test 2: Async callable class with sync wrapper then partial
class AsyncCallable:
async def __call__(self, value: str) -> str:
return value


async_callable_inst = AsyncCallable()
wrapped_async_callable = sync_wrapper(async_callable_inst)
partial_wrapped_async_callable = partial(wrapped_async_callable, "async_callable")


@app.get("/wrapped-async-callable/")
async def endpoint_async_callable(value: str = Depends(partial_wrapped_async_callable)):
return value


# Test 3: Async generator callable class with wrapper and partial
class AsyncGenCallable:
async def __call__(self, value: str) -> AsyncGenerator:
yield value


async_gen_callable_inst = AsyncGenCallable()
wrapped_async_gen_callable = sync_wrapper(async_gen_callable_inst)
partial_wrapped_async_gen_callable = partial(wrapped_async_gen_callable, "async_gen_callable")


@app.get("/wrapped-async-gen-callable/")
async def endpoint_async_gen_callable(value: str = Depends(partial_wrapped_async_gen_callable)):
return value


# Test 4: Endpoint using partial wrapped in async decorator
def endpoint_func():
return "partial_endpoint"


partial_endpoint = partial(endpoint_func)
wrapped_partial_endpoint = async_wrapper(partial_endpoint)


app.get("/wrapped-partial-endpoint/")(wrapped_partial_endpoint)


client = TestClient(app)


def test_multi_wrapped():
"""Test multiple wrapping layers."""
response = client.get("/multi-wrapped/")
assert response.status_code == 200, response.text
assert response.json() == "multi"


def test_wrapped_async_callable():
"""Test async callable class with sync wrapper then partial."""
response = client.get("/wrapped-async-callable/")
assert response.status_code == 200, response.text
assert response.json() == "async_callable"


def test_wrapped_async_gen_callable():
"""Test async generator callable class with wrapper and partial."""
response = client.get("/wrapped-async-gen-callable/")
assert response.status_code == 200, response.text
assert response.json() == "async_gen_callable"


def test_wrapped_partial_endpoint():
"""Test endpoint using partial wrapped in async decorator."""
response = client.get("/wrapped-partial-endpoint/")
assert response.status_code == 200, response.text
assert response.json() == "partial_endpoint"
207 changes: 207 additions & 0 deletions tests/test_wrapped_partial_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
Tests for functools.wraps and functools.partial combined.

This test file verifies that FastAPI correctly detects async/sync functions
when functools.wraps and functools.partial are combined in various ways.
These tests should fail on the base commit (before fix) and pass with the fix.
"""
from functools import wraps, partial
from typing import AsyncGenerator, Generator

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient


app = FastAPI()


# Decorator that creates an async wrapper
def async_wrapper_decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if hasattr(result, '__await__'):
result = await result
return result

return wrapper


# Test 1: Exact issue from problem statement - sync function with async wrapper
def my_decorator(func):
@wraps(func)
async def wrapper():
func()
return "OK"

return wrapper


@app.get("/issue-example/")
@my_decorator
def index():
print("Hello!")


# Test 2: Partial wrapped with async decorator (dependencies)
def base_func(value: str) -> str:
return value


partial_func_async_wrapped = async_wrapper_decorator(partial(base_func, "test2"))


@app.get("/partial-async-wrapped-dep/")
async def endpoint2(value: str = Depends(partial_func_async_wrapped)):
return value


# Test 3: Wrapped function then partial (dependencies)
wrapped_func = async_wrapper_decorator(base_func)
partial_wrapped_func = partial(wrapped_func, "test3")


@app.get("/wrapped-then-partial-dep/")
async def endpoint3(value: str = Depends(partial_wrapped_func)):
return value


# Test 4: Async function with sync wrapper
async def async_base_func(value: str) -> str:
return value


def sync_wrapper_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


wrapped_async = sync_wrapper_decorator(async_base_func)


@app.get("/async-sync-wrapped-dep/")
async def endpoint4(value: str = Depends(partial(wrapped_async, "test4"))):
return value


# Test 5: Sync generator with async wrapper (wrapper converts to async generator)
def sync_gen_base(value: str) -> Generator[str, None, None]:
yield value


def async_gen_wrapper(func):
"""Wrapper that converts sync generator to async generator."""
@wraps(func)
async def wrapper(*args, **kwargs):
gen = func(*args, **kwargs)
for item in gen:
yield item
return wrapper


wrapped_sync_gen = async_gen_wrapper(sync_gen_base)
partial_wrapped_gen = partial(wrapped_sync_gen, "test5")


@app.get("/sync-gen-async-wrapped-dep/")
async def endpoint5(value: str = Depends(partial_wrapped_gen)):
return value


# Test 6: Async generator with sync wrapper then partial
async def async_gen(value: str) -> AsyncGenerator[str, None]:
yield value


wrapped_async_gen = sync_wrapper_decorator(async_gen)
partial_wrapped_async_gen = partial(wrapped_async_gen, "test6")


@app.get("/async-gen-sync-wrapped-dep/")
async def endpoint6(value: str = Depends(partial_wrapped_async_gen)):
return value


# Test 7: Endpoint (path operation) with async wrapper
@app.get("/endpoint-async-wrapped/")
@async_wrapper_decorator
def endpoint_with_wrapper():
return "test7"


# Test 8: Callable class instance with wrapped __call__ then partial
class CallableClass:
def __call__(self, value: str) -> str:
return value


callable_inst = CallableClass()
wrapped_callable = async_wrapper_decorator(callable_inst)
partial_wrapped_callable = partial(wrapped_callable, "test8")


@app.get("/wrapped-callable-partial-dep/")
async def endpoint8(value: str = Depends(partial_wrapped_callable)):
return value


client = TestClient(app)


def test_issue_example():
"""Test the exact issue from the problem statement."""
response = client.get("/issue-example/")
assert response.status_code == 200, response.text
assert response.json() == "OK"


def test_partial_async_wrapped_dep():
"""Test partial wrapped with async decorator."""
response = client.get("/partial-async-wrapped-dep/")
assert response.status_code == 200, response.text
assert response.json() == "test2"


def test_wrapped_then_partial_dep():
"""Test wrapped function then partial."""
response = client.get("/wrapped-then-partial-dep/")
assert response.status_code == 200, response.text
assert response.json() == "test3"


def test_async_sync_wrapped_dep():
"""Test async function with sync wrapper."""
response = client.get("/async-sync-wrapped-dep/")
assert response.status_code == 200, response.text
assert response.json() == "test4"


def test_sync_gen_async_wrapped_dep():
"""Test sync generator with async wrapper."""
response = client.get("/sync-gen-async-wrapped-dep/")
assert response.status_code == 200, response.text
assert response.json() == "test5"


def test_async_gen_sync_wrapped_dep():
"""Test async generator with sync wrapper then partial."""
response = client.get("/async-gen-sync-wrapped-dep/")
assert response.status_code == 200, response.text
assert response.json() == "test6"


def test_endpoint_async_wrapped():
"""Test endpoint with async wrapper."""
response = client.get("/endpoint-async-wrapped/")
assert response.status_code == 200, response.text
assert response.json() == "test7"


def test_wrapped_callable_partial_dep():
"""Test callable class instance with wrapped __call__ then partial."""
response = client.get("/wrapped-callable-partial-dep/")
assert response.status_code == 200, response.text
assert response.json() == "test8"