Skip to content

Commit ed8d4f8

Browse files
Added testing for src/AWS.jl
- Mock AWS._http_request() - Minor string performance improvements - Formatting styles to align with BlueStyle - AWS sign functions return the request they are modifying - Service URL generation takes in region::String rather than config::AWSConfig - Added return_headers as kwarg for AWS Requests - Added testing for src/AWS.jl functionality
1 parent 5e623d2 commit ed8d4f8

File tree

4 files changed

+397
-83
lines changed

4 files changed

+397
-83
lines changed

src/AWS.jl

Lines changed: 42 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using AWSCore
44
using Base64
55
using HTTP
66
using MbedTLS
7+
using Mocking
78
using OrderedCollections: LittleDict, OrderedDict
89
using Retry
910
using Sockets
@@ -97,10 +98,11 @@ function _sign_aws2!(aws::AWSConfig, request::LittleDict, time::DateTime)
9798
query = Pair[k => query[k] for k in sort(collect(keys(query)))]
9899
uri = HTTP.URI(request[:url])
99100
to_sign = "POST\n$(uri.host)\n$(uri.path)\n$(HTTP.escapeuri(query))"
100-
secret = creds.secret_key
101-
push!(query, "Signature" => digest(MD_SHA256, to_sign, secret) |> base64encode |> strip)
101+
push!(query, "Signature" => digest(MD_SHA256, to_sign, creds.secret_key) |> base64encode |> strip)
102102

103103
request[:content] = HTTP.escapeuri(query)
104+
105+
return request
104106
end
105107

106108
function _sign_aws4!(aws::AWSConfig, request::LittleDict, time::DateTime)
@@ -114,7 +116,7 @@ function _sign_aws4!(aws::AWSConfig, request::LittleDict, time::DateTime)
114116
authentication_scope = [date, aws.region, request[:service], "aws4_request"]
115117

116118
creds = check_credentials(aws.credentials)
117-
signing_key = string("AWS4", creds.secret_key)
119+
signing_key = "AWS4$(creds.secret_key)"
118120

119121
for scope in authentication_scope
120122
signing_key = digest(MD_SHA256, scope, signing_key)
@@ -130,11 +132,11 @@ function _sign_aws4!(aws::AWSConfig, request::LittleDict, time::DateTime)
130132
delete!(request[:headers], "Authorization")
131133
merge!(request[:headers], Dict(
132134
"x-amz-content-sha256" => content_hash,
133-
"x-amz-date" => datetime,
134-
"Content-MD5" => base64encode(digest(MD_MD5, request[:content]))
135+
"x-amz-date" => datetime,
136+
"Content-MD5" => base64encode(digest(MD_MD5, request[:content]))
135137
))
136138

137-
if creds.token != ""
139+
if !isempty(creds.token)
138140
request[:headers]["x-amz-security-token"] = creds.token
139141
end
140142

@@ -148,12 +150,13 @@ function _sign_aws4!(aws::AWSConfig, request::LittleDict, time::DateTime)
148150
query = Pair[k => query[k] for k in sort(collect(keys(query)))]
149151

150152
# Create hash of canonical request...
151-
canonical_form = string(request[:request_method], "\n",
152-
request[:service] == "s3" ? uri.path : HTTP.escapepath(uri.path), "\n",
153-
HTTP.escapeuri(query), "\n",
154-
join(sort(canonical_headers), "\n"), "\n\n",
155-
signed_headers, "\n",
156-
content_hash
153+
canonical_form = string(
154+
request[:request_method], "\n",
155+
request[:service] == "s3" ? uri.path : HTTP.escapepath(uri.path), "\n",
156+
HTTP.escapeuri(query), "\n",
157+
join(sort(canonical_headers), "\n"), "\n\n",
158+
signed_headers, "\n",
159+
content_hash
157160
)
158161

159162
canonical_hash = bytes2hex(digest(MD_SHA256, canonical_form))
@@ -169,6 +172,8 @@ function _sign_aws4!(aws::AWSConfig, request::LittleDict, time::DateTime)
169172
"SignedHeaders=$signed_headers, ",
170173
"Signature=$signature"
171174
)
175+
176+
return request
172177
end
173178

174179
function _http_request(aws::AWSConfig, request::LittleDict)
@@ -211,10 +216,11 @@ function _http_request(aws::AWSConfig, request::LittleDict)
211216
end
212217
end
213218

214-
function do_request(aws::AWSConfig, request::LittleDict; return_headers=false)
219+
function do_request(aws::AWSConfig, request::AbstractDict; return_headers=false)
215220
response = nothing
216221
TOO_MANY_REQUESTS = 429
217-
REDIRECT_CODES = [301, 302, 303, 304, 305, 307, 308]
222+
EXPIRED_ERROR_CODES = ["ExpiredToken", "ExpiredTokenException", "RequestExpired"]
223+
REDIRECT_ERROR_CODES = [301, 302, 303, 304, 305, 307, 308]
218224
THROTTLING_ERROR_CODES = [
219225
"Throttling",
220226
"ThrottlingException",
@@ -236,9 +242,9 @@ function do_request(aws::AWSConfig, request::LittleDict; return_headers=false)
236242

237243
@repeat 3 try
238244
aws.credentials === nothing || _sign!(aws, request)
239-
response = _http_request(aws, request)
245+
response = @mock _http_request(aws, request)
240246

241-
if response.status in REDIRECT_CODES && HTTP.header(response, "Location") != ""
247+
if response.status in REDIRECT_ERROR_CODES && HTTP.header(response, "Location") != ""
242248
request[:url] = HTTP.header(response, "Location")
243249
continue
244250
end
@@ -251,8 +257,7 @@ function do_request(aws::AWSConfig, request::LittleDict; return_headers=false)
251257

252258
# Handle ExpiredToken...
253259
# https://github.com/aws/aws-sdk-go/blob/v1.31.5/aws/request/retryer.go#L98
254-
expired_error_codes = ["ExpiredToken", "ExpiredTokenException", "RequestExpired"]
255-
@retry if ecode(e) in expired_error_codes
260+
@retry if ecode(e) in EXPIRED_ERROR_CODES
256261
check_credentials(aws.credentials, force_refresh=true)
257262
end
258263

@@ -302,44 +307,6 @@ function do_request(aws::AWSConfig, request::LittleDict; return_headers=false)
302307
return (return_headers ? (xml_dict(body, xml_dict_type), Dict(response.headers)) : xml_dict(body, xml_dict_type))
303308
end
304309

305-
if occursin(r"/x-amz-json-1.[01]$", mime)
306-
if isempty(response.body)
307-
return nothing
308-
end
309-
310-
if get(r, :ordered_json_dict, true)
311-
return (return_headers ? (JSON.parse(body, dicttype=OrderedDict), Dict(response.headers)) : JSON.parse(body, dicttype=OrderedDict))
312-
else
313-
return (return_headers ? (JSON.parse(body), Dict(response.headers)) : JSON.parse(body))
314-
end
315-
end
316-
317-
if occursin(r"json$", mime)
318-
if isempty(response.body)
319-
return nothing
320-
end
321-
322-
if get(request, :ordered_json_dict, true)
323-
info = JSON.parse(body, dicttype=OrderedDict)
324-
else
325-
info = JSON.parse(body)
326-
end
327-
328-
@protected try
329-
action = request[:query]["Action"]
330-
info = info[action * "Response"]
331-
info = info[action * "Result"]
332-
catch e
333-
@ignore if typeof(e) == KeyError end
334-
end
335-
336-
return (return_headers ? (info, Dict(response.headers)) : info)
337-
end
338-
339-
if occursin(r"^text/", mime)
340-
return (return_headers ? (body, Dict(response.headers)) : body)
341-
end
342-
343310
# Return raw data by default...
344311
return (return_headers ? (response.body, response.headers) : response.body)
345312
end
@@ -356,17 +323,26 @@ function _generate_rest_resource(request_uri::String, args::Dict{String, Any})
356323
return request_uri
357324
end
358325

359-
function _generate_service_url(aws::AWSConfig, request::LittleDict)
326+
function _generate_service_url(region::String, request::AbstractDict)
327+
SERVICE_HOST = "amazonaws.com"
360328
endpoint = get(request, :endpoint, request[:service])
361-
region = aws.region
362329
regionless_endpoints = ("iam", "route53")
363-
service_host = "amazonaws.com"
364330

365331
if endpoint in regionless_endpoints || (endpoint == "sdb" && region == "us-east-1")
366332
region = ""
367333
end
368334

369-
return string("https://", endpoint, ".", isempty(region) ? "" : "$region.", service_host, request[:resource])
335+
return string("https://", endpoint, ".", isempty(region) ? "" : "$region.", SERVICE_HOST, request[:resource])
336+
end
337+
338+
function _return_headers(args::Dict{String, Any})
339+
return_headers = get(args, "return_headers", false)
340+
341+
if return_headers
342+
delete!(args, "return_headers")
343+
end
344+
345+
return return_headers
370346
end
371347

372348
"""
@@ -389,10 +365,7 @@ Perform a RestXML request to AWS
389365
# Returns
390366
- The response from AWS
391367
"""
392-
function (service::RestXMLService)(
393-
aws::AWSConfig, request_method::String, request_uri::String, args=[];
394-
return_headers=false,
395-
)
368+
function (service::RestXMLService)(aws::AWSConfig, request_method::String, request_uri::String, args=[])
396369
request = LittleDict()
397370
args = stringdict(args)
398371
request[:service] = service.name
@@ -405,6 +378,8 @@ function (service::RestXMLService)(
405378
delete!(args, "headers")
406379
delete!(args, "Body")
407380

381+
return_headers = _return_headers(args)
382+
408383
request[:resource] = _generate_rest_resource(request_uri, args)
409384

410385
query_str = HTTP.escapeuri(args)
@@ -418,21 +393,14 @@ function (service::RestXMLService)(
418393
end
419394
end
420395

421-
request[:url] = _generate_service_url(aws, request)
396+
request[:url] = _generate_service_url(aws.region, request)
422397

423398
return do_request(aws, request; return_headers=return_headers)
424399
end
425-
426-
function (service::RestXMLService)(
427-
request_method::String, request_uri::String, args=[];
428-
return_headers=false,
429-
)
430-
return service(AWSConfig(), request_method, request_uri, args; return_headers=return_headers)
431-
end
432-
400+
(service::RestXMLService)(request_method::String, request_uri::String, args=[]) = service(AWSConfig(), request_method, request_uri, args)
433401
(service::RestXMLService)(a...; b...) = service(a..., b)
434402

435-
function (service::QueryService)(aws, operation, args=[])
403+
function (service::QueryService)(aws::AWS.AWSConfig, operation, args=[])
436404
return AWSCore.service_query(
437405
aws;
438406
service=service.name,

0 commit comments

Comments
 (0)