Commit 3ec65b0

Author: Ashutosh Gupta (committed)

Add support to hash via GPUs

Signed-off-by: Ashutosh Gupta <[email protected]>

1 parent ccbcbce, commit 3ec65b0

File tree

  pyproject.toml
  src/model_signing/_hashing/gpu.py
  tests/_hashing/gpu_test.py

3 files changed: +271 −0 lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,10 @@ otel = [
     "opentelemetry-exporter-otlp",
 ]
 
+gpu = [
+    "torch",
+]
+
 [project.scripts]
 model_signing = "model_signing._cli:main"
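With this extra defined, the torch dependency needed by the new GPU engine stays optional and can be pulled in via the install hint quoted in the new module's error message below:

    pip install model-signing[gpu]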

src/model_signing/_hashing/gpu.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
# Copyright 2024 The Sigstore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPU-enabled hashing engines.

This module contains hashing engines backed by PyTorch. When a CUDA capable
GPU is available, the hashing computation runs on the GPU. Otherwise, the
implementation gracefully falls back to CPU execution while exposing the same
API.

The hashing algorithm implemented is SHA256 and mirrors the one from
:mod:`hashlib`. The implementation is self-contained and relies only on basic
PyTorch tensor operations, making it suitable for execution on both CPU and
GPU devices.
"""

from __future__ import annotations

import importlib
from typing import Any

from typing_extensions import override

from model_signing._hashing import hashing


# PyTorch lacks bitwise shifts for unsigned tensors, so we operate on int64 and
# mask values to 32 bits explicitly.
# NOTE: `_MASK32` is a Python integer so that it can be broadcast on any device
# without an explicit copy.
_MASK32 = 0xFFFFFFFF


_K_VALUES = [
    0x428A2F98,
    0x71374491,
    0xB5C0FBCF,
    0xE9B5DBA5,
    0x3956C25B,
    0x59F111F1,
    0x923F82A4,
    0xAB1C5ED5,
    0xD807AA98,
    0x12835B01,
    0x243185BE,
    0x550C7DC3,
    0x72BE5D74,
    0x80DEB1FE,
    0x9BDC06A7,
    0xC19BF174,
    0xE49B69C1,
    0xEFBE4786,
    0x0FC19DC6,
    0x240CA1CC,
    0x2DE92C6F,
    0x4A7484AA,
    0x5CB0A9DC,
    0x76F988DA,
    0x983E5152,
    0xA831C66D,
    0xB00327C8,
    0xBF597FC7,
    0xC6E00BF3,
    0xD5A79147,
    0x06CA6351,
    0x14292967,
    0x27B70A85,
    0x2E1B2138,
    0x4D2C6DFC,
    0x53380D13,
    0x650A7354,
    0x766A0ABB,
    0x81C2C92E,
    0x92722C85,
    0xA2BFE8A1,
    0xA81A664B,
    0xC24B8B70,
    0xC76C51A3,
    0xD192E819,
    0xD6990624,
    0xF40E3585,
    0x106AA070,
    0x19A4C116,
    0x1E376C08,
    0x2748774C,
    0x34B0BCB5,
    0x391C0CB3,
    0x4ED8AA4A,
    0x5B9CCA4F,
    0x682E6FF3,
    0x748F82EE,
    0x78A5636F,
    0x84C87814,
    0x8CC70208,
    0x90BEFFFA,
    0xA4506CEB,
    0xBEF9A3F7,
    0xC67178F2,
]


_H0_VALUES = [
    0x6A09E667,
    0xBB67AE85,
    0x3C6EF372,
    0xA54FF53A,
    0x510E527F,
    0x9B05688C,
    0x1F83D9AB,
    0x5BE0CD19,
]


_TORCH: Any | None = None


def _ensure_torch() -> Any:
    """Import :mod:`torch` on demand."""
    global _TORCH
    if _TORCH is None:
        try:
            _TORCH = importlib.import_module("torch")
        except ModuleNotFoundError as exc:
            raise ModuleNotFoundError(
                "TorchSHA256 requires the optional 'torch' dependency; install "
                "with `pip install model-signing[gpu]`"
            ) from exc
    return _TORCH


def _rotr(x: Any, n: int, torch: Any) -> Any:
    """Right rotation for 32-bit tensors."""
    return ((x >> n) | (x << (32 - n))) & _MASK32


def _sha256_torch(data: bytes, device: Any) -> bytes:
    """Pure PyTorch SHA256 implementation."""
    torch = _ensure_torch()
    msg = bytearray(data)
    bit_len = (len(msg) * 8) & 0xFFFFFFFFFFFFFFFF
    msg.append(0x80)
    while (len(msg) * 8) % 512 != 448:
        msg.append(0)
    msg.extend(bit_len.to_bytes(8, "big"))

    words = torch.tensor(
        [int.from_bytes(msg[i : i + 4], "big") for i in range(0, len(msg), 4)],
        dtype=torch.int64,
        device=device,
    )

    k = torch.tensor(_K_VALUES, dtype=torch.int64, device=device)
    h = torch.tensor(_H0_VALUES, dtype=torch.int64, device=device)
    for chunk_start in range(0, words.shape[0], 16):
        w = torch.zeros(64, dtype=torch.int64, device=device)
        w[:16] = words[chunk_start : chunk_start + 16]
        for i in range(16, 64):
            s0 = (
                _rotr(w[i - 15], 7, torch)
                ^ _rotr(w[i - 15], 18, torch)
                ^ (w[i - 15] >> 3)
            )
            s1 = (
                _rotr(w[i - 2], 17, torch)
                ^ _rotr(w[i - 2], 19, torch)
                ^ (w[i - 2] >> 10)
            )
            w[i] = (w[i - 16] + s0 + w[i - 7] + s1) & _MASK32

        a, b, c, d, e, f, g, hv = h
        for i in range(64):
            s1 = _rotr(e, 6, torch) ^ _rotr(e, 11, torch) ^ _rotr(e, 25, torch)
            ch = (e & f) ^ (((~e) & _MASK32) & g)
            temp1 = (hv + s1 + ch + k[i] + w[i]) & _MASK32
            s0 = _rotr(a, 2, torch) ^ _rotr(a, 13, torch) ^ _rotr(a, 22, torch)
            maj = (a & b) ^ (a & c) ^ (b & c)
            temp2 = (s0 + maj) & _MASK32

            hv = g
            g = f
            f = e
            e = (d + temp1) & _MASK32
            d = c
            c = b
            b = a
            a = (temp1 + temp2) & _MASK32

        h = (h + torch.stack([a, b, c, d, e, f, g, hv])) & _MASK32

    return b"".join(int(x.item()).to_bytes(4, "big") for x in h)


class TorchSHA256(hashing.StreamingHashEngine):
    """SHA256 hashing engine powered by PyTorch."""

    def __init__(self, initial_data: bytes = b"", device: Any | None = None):
        torch = _ensure_torch()
        self._buffer = bytearray(initial_data)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._device = torch.device(device)

    @override
    def update(self, data: bytes) -> None:
        self._buffer.extend(data)

    @override
    def reset(self, data: bytes = b"") -> None:
        self._buffer = bytearray(data)

    @override
    def compute(self) -> hashing.Digest:
        digest_bytes = _sha256_torch(bytes(self._buffer), self._device)
        return hashing.Digest(self.digest_name, digest_bytes)

    @property
    @override
    def digest_name(self) -> str:
        return "sha256"

    @property
    @override
    def digest_size(self) -> int:
        return 32
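For context, a minimal usage sketch of the new engine (assuming the gpu extra is installed; the b"sigstore" payload and the hex printing are purely illustrative, while the calls mirror the class definition above):

from model_signing._hashing.gpu import TorchSHA256

# One-shot: seed the buffer at construction time, then compute the digest.
engine = TorchSHA256(b"sigstore")
digest = engine.compute()
print(digest.algorithm, digest.digest_value.hex())

# Streaming: feed data incrementally; reset() clears the buffer for reuse.
engine.reset()
engine.update(b"sig")
engine.update(b"store")
assert engine.compute().digest_value == digest.digest_value

# Pin execution to the CPU even when a CUDA device is available.
cpu_engine = TorchSHA256(b"sigstore", device="cpu")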

tests/_hashing/gpu_test.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
"""Tests for GPU backed hashing engines."""

# Copyright 2024 The Sigstore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from model_signing._hashing.gpu import TorchSHA256
from model_signing._hashing.memory import SHA256


pytest.importorskip("torch")


def test_torch_sha256_matches_hashlib():
    data = b"sigstore"
    gpu_hasher = TorchSHA256(data)
    cpu_hasher = SHA256(data)
    gpu_digest = gpu_hasher.compute()
    cpu_digest = cpu_hasher.compute()
    assert gpu_digest.digest_value == cpu_digest.digest_value
    assert gpu_digest.algorithm == cpu_digest.algorithm == "sha256"
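A natural follow-up check (not included in this commit) could exercise the streaming update path with an explicitly selected CPU device, compared against hashlib directly; a sketch reusing the imports above:

import hashlib


def test_torch_sha256_streaming_on_cpu():
    # Hash incrementally on an explicitly requested CPU device.
    engine = TorchSHA256(device="cpu")
    engine.update(b"sig")
    engine.update(b"store")
    digest = engine.compute()
    assert digest.digest_value == hashlib.sha256(b"sigstore").digest()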
