6
6
import json
7
7
import urllib .parse
8
8
from io import BytesIO
9
+ from json import dumps
9
10
from typing import Any , Optional , Union
10
11
from typing_extensions import Literal
11
12
12
13
from graphql import ExecutionResult
13
- from webob import Request , Response
14
+ from urllib3 import encode_multipart_formdata
14
15
15
16
from graphql_server .http import GraphQLHTTPResponse
16
17
from graphql_server .http .ides import GraphQL_IDE
17
18
from graphql_server .webob import GraphQLView as BaseGraphQLView
18
19
from tests .http .context import get_context
19
20
from tests .views .schema import Query , schema
21
+ from webob import Request , Response
20
22
21
- from .base import JSON , HttpClient , Response as ClientResponse , ResultOverrideFunction
23
+ from .base import JSON , HttpClient , ResultOverrideFunction
24
+ from .base import Response as ClientResponse
22
25
23
26
24
27
class GraphQLView (BaseGraphQLView [dict [str , object ], object ]):
@@ -82,18 +85,16 @@ async def _graphql_request(
82
85
83
86
url = "/graphql"
84
87
85
- if body and files :
86
- body .update ({name : (file , name ) for name , file in files .items ()})
88
+ headers = self ._get_headers (method = method , headers = headers , files = files )
87
89
88
90
if method == "get" :
89
91
body_encoded = urllib .parse .urlencode (body or {})
90
92
url = f"{ url } ?{ body_encoded } "
91
- else :
92
- if body :
93
- data = body if files else json .dumps (body )
94
- kwargs ["body" ] = data
95
-
96
- headers = self ._get_headers (method = method , headers = headers , files = files )
93
+ elif body :
94
+ if files :
95
+ header_pairs , body = create_multipart_request_body (body , files )
96
+ headers = dict (header_pairs )
97
+ kwargs ["body" ] = body
97
98
98
99
return await self .request (url , method , headers = headers , ** kwargs )
99
100
@@ -104,9 +105,11 @@ def _do_request(
104
105
headers : Optional [dict [str , str ]] = None ,
105
106
** kwargs : Any ,
106
107
) -> ClientResponse :
107
- body = kwargs .get ("body" , None )
108
+ body = kwargs .pop ("body" , None )
109
+ if isinstance (body , dict ):
110
+ body = json .dumps (body ).encode ("utf-8" )
108
111
req = Request .blank (
109
- url , method = method .upper (), headers = headers or {}, body = body
112
+ url , method = method .upper (), headers = headers or {}, body = body , ** kwargs
110
113
)
111
114
resp = self .view .dispatch_request (req )
112
115
return ClientResponse (
@@ -139,5 +142,26 @@ async def post(
139
142
json : Optional [JSON ] = None ,
140
143
headers : Optional [dict [str , str ]] = None ,
141
144
) -> ClientResponse :
142
- body = json if json is not None else data
145
+ body = dumps ( json ). encode ( "utf-8" ) if json is not None else data
143
146
return await self .request (url , "post" , headers = headers , body = body )
147
+
148
+
149
+ def create_multipart_request_body (
150
+ body : dict [str , object ], files : dict [str , BytesIO ]
151
+ ) -> tuple [list [tuple [str , str ]], bytes ]:
152
+ fields = {
153
+ "operations" : body ["operations" ],
154
+ "map" : body ["map" ],
155
+ }
156
+
157
+ for filename , data in files .items ():
158
+ fields [filename ] = (filename , data .read ().decode (), "text/plain" )
159
+
160
+ request_body , content_type_header = encode_multipart_formdata (fields )
161
+
162
+ headers = [
163
+ ("Content-Type" , content_type_header ),
164
+ ("Content-Length" , f"{ len (request_body )} " ),
165
+ ]
166
+
167
+ return headers , request_body
0 commit comments