Skip to content

Commit 909230b

Browse files
add a timing wrapper around model serialization
Signed-off-by: Spencer Schrock <[email protected]>
1 parent d7f8b57 commit 909230b

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

benchmarks/time_serialize.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 The Sigstore Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Script for timing model serialization benchmarks."""
17+
18+
import argparse
19+
import json
20+
import sys
21+
import time
22+
23+
import cpuinfo
24+
import psutil
25+
import serialize
26+
27+
28+
def build_parser() -> argparse.ArgumentParser:
29+
"""Builds the command line parser to benchmark serializing models."""
30+
parser = argparse.ArgumentParser(description="model benchmark data")
31+
32+
parser.add_argument("path", help="path to model")
33+
34+
parser.add_argument(
35+
"--repeat",
36+
help="how many times to repeat each model",
37+
type=int,
38+
default=6,
39+
)
40+
41+
parser.add_argument("--output", "-o", help="path for result file")
42+
43+
return parser
44+
45+
46+
if __name__ == "__main__":
47+
args = build_parser().parse_args()
48+
49+
serialize_args = serialize.build_parser().parse_args(
50+
[args.path, "--use_shards"]
51+
)
52+
53+
results = dict()
54+
results["model"] = args.path
55+
results["ram"] = f"{psutil.virtual_memory().total / 1024**3} GiB"
56+
57+
times = list()
58+
for _ in range(args.repeat):
59+
st = time.time()
60+
payload = serialize.run(serialize_args)
61+
en = time.time()
62+
times.append(en - st)
63+
64+
results["times"] = times
65+
results["cpu"] = cpuinfo.get_cpu_info()
66+
67+
if args.output:
68+
with open(args.output, "w", encoding="utf-8") as f:
69+
json.dump(results, f, ensure_ascii=False, indent=4)
70+
else:
71+
json.dump(results, sys.stdout, ensure_ascii=False, indent=4)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ Use `hatch run +py=3... bench:chunk ${args}` to benchmark the chunk size paramet
8989
"""
9090
extra-dependencies = [
9191
"numpy",
92+
"psutil",
93+
"py-cpuinfo",
9294
]
9395

9496
[[tool.hatch.envs.bench.matrix]]
@@ -117,6 +119,7 @@ description = """Custom environment for pytype.
117119
Use `hatch run type:check` to check types.
118120
"""
119121
extra-dependencies = [
122+
"py-cpuinfo",
120123
"pytest",
121124
"pytype",
122125
]

0 commit comments

Comments
 (0)