Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ otel = [
"opentelemetry-exporter-otlp",
]

gpu = [
"torch",
]

[project.scripts]
model_signing = "model_signing._cli:main"

Expand Down
236 changes: 236 additions & 0 deletions src/model_signing/_hashing/gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Copyright 2024 The Sigstore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPU-enabled hashing engines.

This module contains hashing engines backed by PyTorch. When a CUDA capable
GPU is available, the hashing computation runs on the GPU. Otherwise, the
implementation gracefully falls back to CPU execution while exposing the same
API.

The hashing algorithm implemented is SHA256 and mirrors the one from
:mod:`hashlib`. The implementation is self-contained and relies only on basic
PyTorch tensor operations, making it suitable for execution on both CPU and
GPU devices.
"""

from __future__ import annotations

import importlib
from typing import Any

from typing_extensions import override

from model_signing._hashing import hashing


# PyTorch lacks bitwise shifts for unsigned tensors, so we operate on int64 and
# mask values to 32 bits explicitly.
# NOTE: `_MASK32` is a Python integer so that it can be broadcast on any device
# without an explicit copy.
_MASK32 = 0xFFFFFFFF


_K_VALUES = [
0x428A2F98,
0x71374491,
0xB5C0FBCF,
0xE9B5DBA5,
0x3956C25B,
0x59F111F1,
0x923F82A4,
0xAB1C5ED5,
0xD807AA98,
0x12835B01,
0x243185BE,
0x550C7DC3,
0x72BE5D74,
0x80DEB1FE,
0x9BDC06A7,
0xC19BF174,
0xE49B69C1,
0xEFBE4786,
0x0FC19DC6,
0x240CA1CC,
0x2DE92C6F,
0x4A7484AA,
0x5CB0A9DC,
0x76F988DA,
0x983E5152,
0xA831C66D,
0xB00327C8,
0xBF597FC7,
0xC6E00BF3,
0xD5A79147,
0x06CA6351,
0x14292967,
0x27B70A85,
0x2E1B2138,
0x4D2C6DFC,
0x53380D13,
0x650A7354,
0x766A0ABB,
0x81C2C92E,
0x92722C85,
0xA2BFE8A1,
0xA81A664B,
0xC24B8B70,
0xC76C51A3,
0xD192E819,
0xD6990624,
0xF40E3585,
0x106AA070,
0x19A4C116,
0x1E376C08,
0x2748774C,
0x34B0BCB5,
0x391C0CB3,
0x4ED8AA4A,
0x5B9CCA4F,
0x682E6FF3,
0x748F82EE,
0x78A5636F,
0x84C87814,
0x8CC70208,
0x90BEFFFA,
0xA4506CEB,
0xBEF9A3F7,
0xC67178F2,
]


_H0_VALUES = [
0x6A09E667,
0xBB67AE85,
0x3C6EF372,
0xA54FF53A,
0x510E527F,
0x9B05688C,
0x1F83D9AB,
0x5BE0CD19,
]


_TORCH: Any | None = None


def _ensure_torch() -> Any:
"""Import :mod:`torch` on demand."""
global _TORCH
if _TORCH is None:
try:
_TORCH = importlib.import_module("torch")
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"TorchSHA256 requires the optional 'torch' dependency; install "
"with `pip install model-signing[gpu]`"
) from exc
return _TORCH

Check warning on line 138 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following lines were not covered in your tests: 130 to 138,


def _rotr(x: Any, n: int, torch: Any) -> Any:
"""Right rotation for 32-bit tensors."""
return ((x >> n) | (x << (32 - n))) & _MASK32

Check warning on line 143 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following line was not covered in your tests: 143,


def _sha256_torch(data: bytes, device: Any) -> bytes:
"""Pure PyTorch SHA256 implementation."""
torch = _ensure_torch()
msg = bytearray(data)
bit_len = (len(msg) * 8) & 0xFFFFFFFFFFFFFFFF
msg.append(0x80)
while (len(msg) * 8) % 512 != 448:
msg.append(0)
msg.extend(bit_len.to_bytes(8, "big"))

words = torch.tensor(
[int.from_bytes(msg[i : i + 4], "big") for i in range(0, len(msg), 4)],
dtype=torch.int64,
device=device,
)

k = torch.tensor(_K_VALUES, dtype=torch.int64, device=device)
h = torch.tensor(_H0_VALUES, dtype=torch.int64, device=device)
for chunk_start in range(0, words.shape[0], 16):
w = torch.zeros(64, dtype=torch.int64, device=device)
w[:16] = words[chunk_start : chunk_start + 16]
for i in range(16, 64):
s0 = (
_rotr(w[i - 15], 7, torch)
^ _rotr(w[i - 15], 18, torch)
^ (w[i - 15] >> 3)
)
s1 = (
_rotr(w[i - 2], 17, torch)
^ _rotr(w[i - 2], 19, torch)
^ (w[i - 2] >> 10)
)
w[i] = (w[i - 16] + s0 + w[i - 7] + s1) & _MASK32

a, b, c, d, e, f, g, hv = h
for i in range(64):
s1 = _rotr(e, 6, torch) ^ _rotr(e, 11, torch) ^ _rotr(e, 25, torch)
ch = (e & f) ^ (((~e) & _MASK32) & g)
temp1 = (hv + s1 + ch + k[i] + w[i]) & _MASK32
s0 = _rotr(a, 2, torch) ^ _rotr(a, 13, torch) ^ _rotr(a, 22, torch)
maj = (a & b) ^ (a & c) ^ (b & c)
temp2 = (s0 + maj) & _MASK32

hv = g
g = f
f = e
e = (d + temp1) & _MASK32
d = c
c = b
b = a
a = (temp1 + temp2) & _MASK32

h = (h + torch.stack([a, b, c, d, e, f, g, hv])) & _MASK32

return b"".join(int(x.item()).to_bytes(4, "big") for x in h)

Check warning on line 200 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following lines were not covered in your tests: 148 to 200,


class TorchSHA256(hashing.StreamingHashEngine):
"""SHA256 hashing engine powered by PyTorch."""

def __init__(self, initial_data: bytes = b"", device: Any | None = None):
torch = _ensure_torch()
self._buffer = bytearray(initial_data)
if device is None:
device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self._device = torch.device(device)

Check warning on line 213 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following lines were not covered in your tests: 207 to 213,

@override
def update(self, data: bytes) -> None:
self._buffer.extend(data)

Check warning on line 217 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following line was not covered in your tests: 217,

@override
def reset(self, data: bytes = b"") -> None:
self._buffer = bytearray(data)

Check warning on line 221 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following line was not covered in your tests: 221,

@override
def compute(self) -> hashing.Digest:
digest_bytes = _sha256_torch(bytes(self._buffer), self._device)
return hashing.Digest(self.digest_name, digest_bytes)

Check warning on line 226 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following lines were not covered in your tests: 225 to 226,

@property
@override
def digest_name(self) -> str:
return "sha256"

Check warning on line 231 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following line was not covered in your tests: 231,

@property
@override
def digest_size(self) -> int:
return 32

Check warning on line 236 in src/model_signing/_hashing/gpu.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following line was not covered in your tests: 236
33 changes: 33 additions & 0 deletions tests/_hashing/gpu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Tests for GPU backed hashing engines."""

# Copyright 2024 The Sigstore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from model_signing._hashing.gpu import TorchSHA256
from model_signing._hashing.memory import SHA256


pytest.importorskip("torch")


def test_torch_sha256_matches_hashlib():
data = b"sigstore"
gpu_hasher = TorchSHA256(data)
cpu_hasher = SHA256(data)
gpu_digest = gpu_hasher.compute()
cpu_digest = cpu_hasher.compute()
assert gpu_digest.digest_value == cpu_digest.digest_value
assert gpu_digest.algorithm == cpu_digest.algorithm == "sha256"
Loading