
Commit d507c51

CaptainSame and Sameer Sharma authored
MLCOMPUTE-1497 | add methods to get total driver memory including overhead (#146)
* MLCOMPUTE-1209 | add methods to get total driver memory including overhead
* fix tests
* MLCOMPUTE-1209 | bump up version

Co-authored-by: Sameer Sharma <[email protected]>
1 parent 9719604 commit d507c51

File tree

4 files changed (+143 −1 lines changed):
  service_configuration_lib/spark_config.py
  service_configuration_lib/utils.py
  setup.py
  tests/utils_test.py

service_configuration_lib/spark_config.py

Lines changed: 4 additions & 0 deletions
@@ -414,6 +414,10 @@ def _filter_user_spark_opts(user_spark_opts: Mapping[str, str]) -> MutableMappin
     }
 
 
+def get_total_driver_memory_mb(spark_conf: Dict[str, str]) -> int:
+    return int(utils.get_spark_driver_memory_mb(spark_conf) + utils.get_spark_driver_memory_overhead_mb(spark_conf))
+
+
 class SparkConfBuilder:
 
     def __init__(self):
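
For reference, a minimal usage sketch of the new helper (not part of this diff; the configuration values below are illustrative):

from service_configuration_lib import spark_config

# Illustrative driver settings; real values come from srv-configs / SparkConfBuilder.
conf = {'spark.driver.memory': '4g', 'spark.driver.memoryOverheadFactor': '0.1'}

# 4096 MB base + 409.6 MB overhead, truncated by int() to 4505.
total_mb = spark_config.get_total_driver_memory_mb(conf)
print(total_mb)  # 4505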

service_configuration_lib/utils.py

Lines changed: 60 additions & 0 deletions
@@ -11,10 +11,12 @@
 from socket import SO_REUSEADDR
 from socket import socket
 from socket import SOL_SOCKET
+from typing import Dict
 from typing import Mapping
 from typing import Tuple
 
 import yaml
+from typing_extensions import Literal
 
 DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
 POD_TEMPLATE_PATH = '/nail/tmp/spark-pt-{file_uuid}.yaml'
@@ -24,6 +26,11 @@
 EPHEMERAL_PORT_START = 49152
 EPHEMERAL_PORT_END = 65535
 
+MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4}
+
+SPARK_DRIVER_MEM_DEFAULT_MB = 2048
+SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT = 0.1
+
 
 log = logging.Logger(__name__)
 log.setLevel(logging.INFO)
@@ -148,3 +155,56 @@ def get_runtime_env() -> str:
     # we could also just crash or return None, but this seems a little easier to find
     # should we somehow run into this at Yelp
     return 'unknown'
+
+
+def get_spark_memory_in_unit(mem: str, unit: Literal['k', 'm', 'g', 't']) -> float:
+    """
+    Converts Spark memory to the desired unit.
+    mem is the same format as JVM memory strings: just number or number followed by 'k', 'm', 'g' or 't'.
+    unit can be 'k', 'm', 'g' or 't'.
+    Returns memory as a float converted to the desired unit.
+    """
+    try:
+        memory_bytes = float(mem)
+    except ValueError:
+        try:
+            memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
+        except (ValueError, IndexError):
+            print(f'Unable to parse memory value {mem}.')
+            raise
+    memory_unit = memory_bytes / MEM_MULTIPLIER[unit]
+    return round(memory_unit, 5)
+
+
+def get_spark_driver_memory_mb(spark_conf: Dict[str, str]) -> float:
+    """
+    Returns the Spark driver memory in MB.
+    """
+    # spark_conf is expected to have "spark.driver.memory" since it is a mandatory default from srv-configs.
+    driver_mem = spark_conf['spark.driver.memory']
+    try:
+        return get_spark_memory_in_unit(str(driver_mem), 'm')
+    except (ValueError, IndexError):
+        return SPARK_DRIVER_MEM_DEFAULT_MB
+
+
+def get_spark_driver_memory_overhead_mb(spark_conf: Dict[str, str]) -> float:
+    """
+    Returns the Spark driver memory overhead in MB.
+    """
+    # Use spark.driver.memoryOverhead if it is set.
+    try:
+        driver_mem_overhead = spark_conf['spark.driver.memoryOverhead']
+        try:
+            # spark.driver.memoryOverhead default unit is MB
+            driver_mem_overhead_mb = float(driver_mem_overhead)
+        except ValueError:
+            driver_mem_overhead_mb = get_spark_memory_in_unit(str(driver_mem_overhead), 'm')
+    # Calculate spark.driver.memoryOverhead based on spark.driver.memory and spark.driver.memoryOverheadFactor.
+    except Exception:
+        driver_mem_mb = get_spark_driver_memory_mb(spark_conf)
+        driver_mem_overhead_factor = float(
+            spark_conf.get('spark.driver.memoryOverheadFactor', SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT),
+        )
+        driver_mem_overhead_mb = driver_mem_mb * driver_mem_overhead_factor
+    return round(driver_mem_overhead_mb, 5)
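
A quick illustrative check of the conversion and fallback behaviour added above (not part of this diff; the values mirror the parametrized tests further down):

from service_configuration_lib import utils

# JVM-style strings are converted via MEM_MULTIPLIER: '1.5g' -> 1536.0 MB.
assert utils.get_spark_memory_in_unit('1.5g', 'm') == 1536.0

# If spark.driver.memoryOverhead is set, it is used directly (default unit is MB).
assert utils.get_spark_driver_memory_overhead_mb({'spark.driver.memoryOverhead': '1g'}) == 1024.0

# Otherwise the overhead falls back to memory * memoryOverheadFactor (0.1 by default).
assert utils.get_spark_driver_memory_overhead_mb({'spark.driver.memory': '10240m'}) == 1024.0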

setup.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 setup(
     name='service-configuration-lib',
-    version='2.18.19',
+    version='2.18.20',
     provides=['service_configuration_lib'],
     description='Start, stop, and inspect Yelp SOA services',
     url='https://github.com/Yelp/service_configuration_lib',

tests/utils_test.py

Lines changed: 78 additions & 0 deletions
@@ -2,11 +2,13 @@
 from socket import SO_REUSEADDR
 from socket import socket as Socket
 from socket import SOL_SOCKET
+from typing import cast
 from unittest import mock
 from unittest.mock import mock_open
 from unittest.mock import patch
 
 import pytest
+from typing_extensions import Literal
 
 from service_configuration_lib import utils
 from service_configuration_lib.utils import ephemeral_port_reserve_range
@@ -74,6 +76,82 @@ def test_generate_pod_template_path(hex_value):
     assert utils.generate_pod_template_path() == f'/nail/tmp/spark-pt-{hex_value}.yaml'
 
 
+@pytest.mark.parametrize(
+    'mem_str,unit_str,expected_mem',
+    (
+        ('13425m', 'm', 13425),  # Simple case
+        ('138412032', 'm', 132),  # Bytes to MB
+        ('65536k', 'g', 0.0625),  # KB to GB
+        ('1t', 'g', 1024),  # TB to GB
+        ('1.5g', 'm', 1536),  # GB to MB with decimal
+        ('2048k', 'm', 2),  # KB to MB
+        ('0.5g', 'k', 524288),  # GB to KB
+        ('32768m', 't', 0.03125),  # MB to TB
+        ('1.5t', 'm', 1572864),  # TB to MB with decimal
+    ),
+)
+def test_get_spark_memory_in_unit(mem_str, unit_str, expected_mem):
+    assert expected_mem == utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))
+
+
+@pytest.mark.parametrize(
+    'mem_str,unit_str',
+    [
+        ('invalid', 'm'),
+        ('1024mb', 'g'),
+    ],
+)
+def test_get_spark_memory_in_unit_exceptions(mem_str, unit_str):
+    with pytest.raises((ValueError, IndexError)):
+        utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))
+
+
+@pytest.mark.parametrize(
+    'spark_conf,expected_mem',
+    [
+        ({'spark.driver.memory': '13425m'}, 13425),  # Simple case
+        ({'spark.driver.memory': '138412032'}, 132),  # Bytes to MB
+        ({'spark.driver.memory': '65536k'}, 64),  # KB to MB
+        ({'spark.driver.memory': '1g'}, 1024),  # GB to MB
+        ({'spark.driver.memory': 'invalid'}, utils.SPARK_DRIVER_MEM_DEFAULT_MB),  # Invalid case
+        ({'spark.driver.memory': '1.5g'}, 1536),  # GB to MB with decimal
+        ({'spark.driver.memory': '2048k'}, 2),  # KB to MB
+        ({'spark.driver.memory': '0.5t'}, 524288),  # TB to MB
+        ({'spark.driver.memory': '1024m'}, 1024),  # MB to MB
+        ({'spark.driver.memory': '1.5t'}, 1572864),  # TB to MB with decimal
+    ],
+)
+def test_get_spark_driver_memory_mb(spark_conf, expected_mem):
+    assert expected_mem == utils.get_spark_driver_memory_mb(spark_conf)
+
+
+@pytest.mark.parametrize(
+    'spark_conf,expected_mem_overhead',
+    [
+        ({'spark.driver.memoryOverhead': '1024'}, 1024),  # Simple case
+        ({'spark.driver.memoryOverhead': '1g'}, 1024),  # GB to MB
+        ({'spark.driver.memory': '10240m', 'spark.driver.memoryOverheadFactor': '0.2'}, 2048),  # Custom OverheadFactor
+        ({'spark.driver.memory': '10240m'}, 1024),  # Using default overhead factor
+        (
+            {'spark.driver.memory': 'invalid'},
+            utils.SPARK_DRIVER_MEM_DEFAULT_MB * utils.SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT,
+        ),
+        # Invalid case
+        ({'spark.driver.memoryOverhead': '1.5g'}, 1536),  # GB to MB with decimal
+        ({'spark.driver.memory': '2048k', 'spark.driver.memoryOverheadFactor': '0.05'}, 0.1),
+        # KB to MB with custom factor
+        ({'spark.driver.memory': '0.5t', 'spark.driver.memoryOverheadFactor': '0.15'}, 78643.2),
+        # TB to MB with custom factor
+        ({'spark.driver.memory': '1024m', 'spark.driver.memoryOverheadFactor': '0.25'}, 256),
+        # MB to MB with custom factor
+        ({'spark.driver.memory': '1.5t', 'spark.driver.memoryOverheadFactor': '0.05'}, 78643.2),
+        # TB to MB with custom factor
+    ],
+)
+def test_get_spark_driver_memory_overhead_mb(spark_conf, expected_mem_overhead):
+    assert expected_mem_overhead == utils.get_spark_driver_memory_overhead_mb(spark_conf)
+
+
 @pytest.fixture
 def mock_runtimeenv():
     with patch('builtins.open', mock_open(read_data=MOCK_ENV_NAME)) as m:
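
The new cases can be run in isolation, for example with a sketch like the following (assumes the repository layout shown above; '-k memory' matches only the memory-helper tests added here):

import pytest

# Select and run just the memory-helper tests from this commit.
pytest.main(['tests/utils_test.py', '-k', 'memory'])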

0 commit comments
