Skip to content

Commit e97a17e

Browse files
jkimbopatrick91
authored andcommitted
Some cleanup to keep all logic in custom ExecutionContext
1 parent 9d8962a commit e97a17e

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

strawberry/schema/execute.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from inspect import isawaitable
33
from typing import Any, Awaitable, Dict, List, Optional, Sequence, Type, cast
44

5-
from promise import Promise, is_thenable
6-
75
from graphql import (
86
ExecutionContext as GraphQLExecutionContext,
97
ExecutionResult as GraphQLExecutionResult,
@@ -12,7 +10,6 @@
1210
execute as original_execute,
1311
parse,
1412
)
15-
from graphql.pyutils import is_awaitable as default_is_awaitable
1613
from graphql.validation import validate
1714

1815
from strawberry.extensions import Extension
@@ -99,12 +96,6 @@ async def execute(
9996
)
10097

10198

102-
def is_awaitable(value):
103-
if is_thenable(value):
104-
return False
105-
return default_is_awaitable(value)
106-
107-
10899
def execute_sync(
109100
schema: GraphQLSchema,
110101
query: str,

strawberry/schema/execute_context.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,23 @@
88
GraphQLOutputType,
99
GraphQLResolveInfo,
1010
)
11-
from graphql.language import FieldNode
12-
from graphql.pyutils import AwaitableOrValue, Path, Undefined
11+
from graphql.language import FieldNode, OperationDefinitionNode
12+
from graphql.pyutils import (
13+
AwaitableOrValue,
14+
Path,
15+
Undefined,
16+
is_awaitable as default_is_awaitable,
17+
)
18+
19+
20+
def is_awaitable(value):
21+
"""
22+
Create custom is_awaitable function to make sure that Promises' aren't
23+
considered awaitable
24+
"""
25+
if is_thenable(value):
26+
return False
27+
return default_is_awaitable(value)
1328

1429

1530
S = TypeVar("S")
@@ -31,6 +46,20 @@ def handle_success(resolved_values: List[S]) -> Dict[Hashable, S]:
3146

3247

3348
class ExecutionContextWithPromise(ExecutionContext):
49+
is_awaitable = staticmethod(is_awaitable)
50+
51+
def execute_operation(
52+
self, operation: OperationDefinitionNode, root_value: Any
53+
) -> Optional[AwaitableOrValue[Any]]:
54+
# Wrap execute in a Promise
55+
original_execute_operation = super().execute_operation
56+
57+
def promise_executor(v):
58+
return original_execute_operation(operation, root_value)
59+
60+
promise = Promise.resolve(None).then(promise_executor)
61+
return promise
62+
3463
def build_response(self, data):
3564
if is_thenable(data):
3665
original_build_response = super().build_response
@@ -43,7 +72,7 @@ def on_resolve(data):
4372
return original_build_response(data)
4473

4574
promise = data.catch(on_rejected).then(on_resolve)
46-
return promise
75+
return promise.get()
4776
return super().build_response(data)
4877

4978
def complete_value_catching_error(

0 commit comments

Comments
 (0)