Skip to content

Commit 2baeaff

Browse files
committed
feat: change to request injection on type rather than positional
1 parent 2a592d2 commit 2baeaff

File tree

6 files changed

+284
-903
lines changed

6 files changed

+284
-903
lines changed
Lines changed: 52 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -1,207 +1,7 @@
11
import inspect
2-
import types
32
import warnings
43
from collections.abc import Callable
5-
from typing import Any, TypeVar, Union, get_args, get_origin
6-
7-
8-
def accepts_single_positional_arg(func: Callable[..., Any]) -> bool:
9-
"""
10-
True if the function accepts at least one positional argument, otherwise false.
11-
12-
This function intentionally does not define behavior for `func`s that
13-
contain more than one positional argument, or any required keyword
14-
arguments without defaults.
15-
"""
16-
try:
17-
sig = inspect.signature(func)
18-
except (ValueError, TypeError):
19-
return False
20-
21-
params = dict(sig.parameters.items())
22-
23-
if len(params) == 0:
24-
# No parameters at all - can't accept single argument
25-
return False
26-
27-
# Check if ALL remaining parameters are keyword-only
28-
all_keyword_only = all(param.kind == inspect.Parameter.KEYWORD_ONLY for param in params.values())
29-
30-
if all_keyword_only:
31-
# If all params are keyword-only, check if they ALL have defaults
32-
# If they do, the function can be called with no arguments -> no argument
33-
all_have_defaults = all(param.default is not inspect.Parameter.empty for param in params.values())
34-
if all_have_defaults:
35-
return False
36-
# otherwise, undefined (doesn't accept a positional argument, and requires at least one keyword only)
37-
38-
# Check if the ONLY parameter is **kwargs (VAR_KEYWORD)
39-
# A function with only **kwargs can't accept a positional argument
40-
if len(params) == 1:
41-
only_param = next(iter(params.values()))
42-
if only_param.kind == inspect.Parameter.VAR_KEYWORD:
43-
return False # Can't pass positional argument to **kwargs
44-
45-
# Has at least one positional or variadic parameter - can accept argument
46-
# Important note: this is designed to _not_ handle the situation where
47-
# there are multiple keyword only arguments with no defaults. In those
48-
# situations it's an invalid handler function, and will error. But it's
49-
# not the responsibility of this function to check the validity of a
50-
# callback.
51-
return True
52-
53-
54-
def get_first_parameter_type(func: Callable[..., Any]) -> Any:
55-
"""
56-
Get the type annotation of the first parameter of a function.
57-
58-
Returns None if:
59-
- The function has no parameters
60-
- The first parameter has no type annotation
61-
- The signature cannot be inspected
62-
63-
Returns the actual annotation otherwise (could be a type, Any, Union, TypeVar, etc.)
64-
"""
65-
try:
66-
sig = inspect.signature(func)
67-
except (ValueError, TypeError):
68-
return None
69-
70-
params = list(sig.parameters.values())
71-
if not params:
72-
return None
73-
74-
first_param = params[0]
75-
76-
# Skip *args and **kwargs
77-
if first_param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
78-
return None
79-
80-
annotation = first_param.annotation
81-
if annotation == inspect.Parameter.empty:
82-
return None
83-
84-
return annotation
85-
86-
87-
def type_accepts_request(param_type: Any, request_type: type) -> bool:
88-
"""
89-
Check if a parameter type annotation can accept the request type.
90-
91-
Handles:
92-
- Exact type match
93-
- Union types (checks if request_type is in the Union)
94-
- TypeVars (checks if request_type matches the bound or constraints)
95-
- Generic types (basic support)
96-
- Any (always returns True)
97-
98-
Returns False for None or incompatible types.
99-
"""
100-
if param_type is None:
101-
return False
102-
103-
# Check for Any type
104-
if param_type is Any:
105-
return True
106-
107-
# Exact match
108-
if param_type == request_type:
109-
return True
110-
111-
# Handle Union types (both typing.Union and | syntax)
112-
origin = get_origin(param_type)
113-
if origin is Union or origin is types.UnionType:
114-
args = get_args(param_type)
115-
# Check if request_type is in the Union
116-
for arg in args:
117-
if arg == request_type:
118-
return True
119-
# Recursively check each union member
120-
if type_accepts_request(arg, request_type):
121-
return True
122-
return False
123-
124-
# Handle TypeVar
125-
if isinstance(param_type, TypeVar):
126-
# Check if request_type matches the bound
127-
if param_type.__bound__ is not None:
128-
if request_type == param_type.__bound__:
129-
return True
130-
# Check if request_type is a subclass of the bound
131-
try:
132-
if issubclass(request_type, param_type.__bound__):
133-
return True
134-
except TypeError:
135-
pass
136-
137-
# Check constraints
138-
if param_type.__constraints__:
139-
for constraint in param_type.__constraints__:
140-
if request_type == constraint:
141-
return True
142-
try:
143-
if issubclass(request_type, constraint):
144-
return True
145-
except TypeError:
146-
pass
147-
148-
return False
149-
150-
# For other generic types, check if request_type matches the origin
151-
if origin is not None:
152-
# Get the base generic type (e.g., list from list[str])
153-
return request_type == origin
154-
155-
return False
156-
157-
158-
def should_pass_request(func: Callable[..., Any], request_type: type) -> tuple[bool, bool]:
159-
"""
160-
Determine if a request should be passed to the function based on parameter type inspection.
161-
162-
Returns a tuple of (should_pass_request, should_deprecate):
163-
- should_pass_request: True if the request should be passed to the function
164-
- should_deprecate: True if a deprecation warning should be issued
165-
166-
The decision logic:
167-
1. If the function has no parameters -> (False, True) - old style without params, deprecate
168-
2. If the function has parameters but can't accept positional args -> (False, False)
169-
3. If the first parameter type accepts the request type -> (True, False) - pass request, no deprecation
170-
4. If the first parameter is typed as Any -> (True, True) - pass request but deprecate (effectively untyped)
171-
5. If the first parameter is typed with something incompatible -> (False, True) - old style, deprecate
172-
6. If the first parameter is untyped but accepts positional args -> (True, True) - pass request, deprecate
173-
"""
174-
can_accept_arg = accepts_single_positional_arg(func)
175-
176-
if not can_accept_arg:
177-
# Check if it has no parameters at all (old style)
178-
try:
179-
sig = inspect.signature(func)
180-
if len(sig.parameters) == 0:
181-
# Old style handler with no parameters - don't pass request but deprecate
182-
return False, True
183-
except (ValueError, TypeError):
184-
pass
185-
# Can't accept positional arguments for other reasons
186-
return False, False
187-
188-
param_type = get_first_parameter_type(func)
189-
190-
if param_type is None:
191-
# Untyped parameter - this is the old style, pass request but deprecate
192-
return True, True
193-
194-
# Check if the parameter type can accept the request
195-
if type_accepts_request(param_type, request_type):
196-
# Check if it's Any - if so, we should deprecate
197-
if param_type is Any:
198-
return True, True
199-
# Properly typed to accept the request - pass request, no deprecation
200-
return True, False
201-
202-
# Parameter is typed with something incompatible - this is an old style handler expecting
203-
# a different signature, don't pass request, issue deprecation
204-
return False, True
4+
from typing import Any, get_type_hints
2055

2066

2077
def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> None:
@@ -215,3 +15,54 @@ def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> N
21515
DeprecationWarning,
21616
stacklevel=4,
21717
)
18+
19+
20+
def create_call_wrapper(func: Callable[..., Any], request_type: type) -> tuple[Callable[[Any], Any], bool]:
21+
"""
22+
Create a wrapper function that knows how to call func with the request object.
23+
24+
Returns a tuple of (wrapper_func, should_deprecate):
25+
- wrapper_func: A function that takes the request and calls func appropriately
26+
- should_deprecate: True if a deprecation warning should be issued
27+
28+
The wrapper handles three calling patterns:
29+
1. Positional-only parameter typed as request_type (no default): func(req)
30+
2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req})
31+
3. No request parameter or parameter with default (deprecated): func()
32+
"""
33+
try:
34+
sig = inspect.signature(func)
35+
type_hints = get_type_hints(func)
36+
except (ValueError, TypeError, NameError):
37+
# Can't inspect signature or resolve type hints, assume no request parameter (deprecated)
38+
return lambda _: func(), True
39+
40+
# Check for positional-only parameter typed as request_type
41+
for param_name, param in sig.parameters.items():
42+
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
43+
param_type = type_hints.get(param_name)
44+
if param_type == request_type:
45+
# Check if it has a default - if so, treat as old style (deprecated)
46+
if param.default is not inspect.Parameter.empty:
47+
return lambda _: func(), True
48+
# Found positional-only parameter with correct type and no default
49+
return lambda req: func(req), False
50+
51+
# Check for any positional/keyword parameter typed as request_type
52+
for param_name, param in sig.parameters.items():
53+
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
54+
param_type = type_hints.get(param_name)
55+
if param_type == request_type:
56+
# Check if it has a default - if so, treat as old style (deprecated)
57+
if param.default is not inspect.Parameter.empty:
58+
return lambda _: func(), True
59+
60+
# Found keyword parameter with correct type and no default
61+
# Need to capture param_name in closure properly
62+
def make_keyword_wrapper(name: str) -> Callable[[Any], Any]:
63+
return lambda req: func(**{name: req})
64+
65+
return make_keyword_wrapper(param_name), False
66+
67+
# No request parameter found - use old style (deprecated)
68+
return lambda _: func(), True

0 commit comments

Comments
 (0)