Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions src/shake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,36 +52,41 @@ function transform!(context::T) where {T<:SHAKE}
end
function digest!(context::T,d::UInt,p::Ptr{UInt8}) where {T<:SHAKE}
usedspace = context.bytecount % blocklen(T)
# If we have anything in the buffer still, pad and transform that data
if usedspace < blocklen(T) - 1
# Begin padding with a 0x1f
context.buffer[usedspace+1] = 0x1f
# Fill with zeros up until the last byte
context.buffer[usedspace+2:end-1] .= 0x00
# Finish it off with a 0x80
context.buffer[end] = 0x80
else
# Otherwise, we have to add on a whole new buffer
context.buffer[end] = 0x9f
if !context.used
# If we have anything in the buffer still, pad and transform that data
if usedspace < blocklen(T) - 1
# Begin padding with a 0x1f
context.buffer[usedspace+1] = 0x1f
# Fill with zeros up until the last byte
context.buffer[usedspace+2:end-1] .= 0x00
# Finish it off with a 0x80
context.buffer[end] = 0x80
else
# Otherwise, we have to add on a whole new buffer
context.buffer[end] = 0x9f
end
# Final transform:
transform!(context)

context.used = true
context.bytecount = 0
usedspace = 0
end
# Final transform:
transform!(context)
# Return the digest:
# fill the given memory via pointer, if d>blocklen, update pointer and digest again.
if d <= blocklen(T)
for i = 1:d
unsafe_store!(p,reinterpret(UInt8, context.state)[i],i)
end
return
else
for i = 1:blocklen(T)
unsafe_store!(p,reinterpret(UInt8, context.state)[i],i)
end
context.used = true
p+=blocklen(T)
next_d_len = UInt(d - blocklen(T))
digest!(context, next_d_len, p)
return
while d > 0
avail = blocklen(T) - usedspace
len = min(d, avail)
for i = 1:len
unsafe_store!(p,reinterpret(UInt8, context.state)[usedspace+i],i)
end
context.bytecount += len
p += len
d = UInt(d - len)
if len == avail
transform!(context)
usedspace = context.bytecount % blocklen(T)
end
end
end

Expand Down
54 changes: 54 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,70 @@ end
@testset "shake128" begin
for (k,v) in SHA128test
@test SHA.shake128(hex2bytes(k[1]),k[2]) == hex2bytes(v)
ctx = SHAKE_128_CTX()
in = hex2bytes(k[1])
idx = 1
while idx <= length(in)
l = min(rand(1:length(in)), length(in) - idx + 1)
update!(ctx, in[idx:idx+l-1])
idx += l
end
out = Vector{UInt8}(undef, k[2])
idx = 0
while idx < k[2]
l = min(rand(1:k[2]), k[2] - idx)
digest!(ctx, l, pointer(out) + idx)
idx += l
end
@test out == hex2bytes(v)
end
@test SHA.shake128(b"",UInt(16)) == hex2bytes("7f9c2ba4e88f827d616045507605853e")
@test SHA.shake128(codeunits("0" ^ 167), UInt(32)) == hex2bytes("ff60b0516fb8a3d4032900976e98b5595f57e9d4a88a0e37f7cc5adfa3c47da2")

for chunksize in UInt[1, 2, 3, 200]
ctx = SHAKE_128_CTX()
out = Vector{UInt8}(undef, 10000)
idx = 0
while idx < length(out)
digest!(ctx, chunksize, pointer(out) + idx)
idx += chunksize
end
@test out == SHA.shake128(UInt8[], UInt(length(out)))
end
end

@testset "shake256" begin
for (k,v) in SHA256test
@test SHA.shake256(hex2bytes(k[1]),k[2]) == hex2bytes(v)
ctx = SHAKE_256_CTX()
in = hex2bytes(k[1])
idx = 1
while idx <= length(in)
l = min(rand(1:length(in)), length(in) - idx + 1)
update!(ctx, in[idx:idx+l-1])
idx += l
end
out = Vector{UInt8}(undef, k[2])
idx = 0
while idx < k[2]
l = min(rand(1:k[2]), k[2] - idx)
digest!(ctx, l, pointer(out) + idx)
idx += l
end
@test out == hex2bytes(v)
end
@test SHA.shake256(b"",UInt(32)) == hex2bytes("46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762f")
@test SHA.shake256(codeunits("0"^135),UInt(32)) == hex2bytes("ab11f61b5085a108a58670a66738ea7a8d8ce23b7c57d64de83eaafb10923cf8")

for chunksize in UInt[1, 2, 3, 200]
ctx = SHAKE_256_CTX()
out = Vector{UInt8}(undef, 10000)
idx = 0
while idx < length(out)
digest!(ctx, chunksize, pointer(out) + idx)
idx += chunksize
end
@test out == SHA.shake256(UInt8[], UInt(length(out)))
end
end
end