 import json
 import os
 
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Disables tensorRT, cuda warnings.
-# disable only for deepspeech if it works fine for other workloads.
-os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
-
 import struct
 import time
 from types import MappingProxyType
@@ -56,6 +51,11 @@
 from algorithmic_efficiency.pytorch_utils import sync_ddp_time
 from algorithmic_efficiency.workloads import workloads
 
+# Environment variables
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Disables tensorRT, cuda warnings.
+# disable only for deepspeech if it works fine for other workloads
+os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
+
 # TODO(znado): make a nicer registry of workloads that lookup in.
 BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR
 
@@ -681,6 +681,14 @@ def main(_):
   else:
     profiler = PassThroughProfiler()
 
+  # Set PyTorch environment variables before initializing with DDP.
+  base_workload = workloads.get_base_workload_name(FLAGS.workload)
+  if base_workload == 'librispeech_conformer':
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+  if FLAGS.set_pytorch_max_split_size:
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
+
   if FLAGS.framework == 'pytorch':
     pytorch_init(USE_PYTORCH_DDP, RANK, profiler)
 
@@ -692,9 +700,6 @@ def main(_):
 
   workload_metadata = WORKLOADS[FLAGS.workload]
 
-  # Prevent OOM on librispeech conformer.
-  base_workload = workloads.get_base_workload_name(FLAGS.workload)
-
   if base_workload in [
       'librispeech_conformer',
       'librispeech_deepspeech',
@@ -703,13 +708,6 @@ def main(_):
   ]:
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
 
-  if base_workload != 'librispeech_conformer':
-    # Remove the environment variable (only for workloads other than librispeech conformer).
-    del os.environ['PYTORCH_CUDA_ALLOC_CONF']
-
-  if FLAGS.set_pytorch_max_split_size:
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
-
   # Extend path according to framework.
   workload_metadata['workload_path'] = os.path.join(
       BASE_WORKLOADS_DIR,
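
For context (not part of the diff): PyTorch reads `PYTORCH_CUDA_ALLOC_CONF` when the CUDA caching allocator is first initialized, so the variable has to be exported before `pytorch_init` touches the GPU and sets up DDP, which is what moving these assignments earlier in `main` accomplishes. Below is a minimal, hypothetical sketch of that ordering; the `setup_ddp` helper and the `LOCAL_RANK`/`WORLD_SIZE` lookups are illustrative and not from this repository.

```python
import os

# The caching-allocator config is read when the allocator is first
# initialized, so it must be in the environment before any CUDA work;
# exporting it afterwards effectively has no effect on this process.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch
import torch.distributed as dist


def setup_ddp(local_rank: int, world_size: int) -> None:
  """Hypothetical helper mirroring the ordering this change enforces."""
  # First CUDA touch: the allocator picks up PYTORCH_CUDA_ALLOC_CONF here.
  torch.cuda.set_device(local_rank)
  # Assumes MASTER_ADDR / MASTER_PORT were exported by the launcher
  # (e.g. torchrun).
  dist.init_process_group('nccl', rank=local_rank, world_size=world_size)


if __name__ == '__main__':
  setup_ddp(int(os.environ.get('LOCAL_RANK', 0)),
            int(os.environ.get('WORLD_SIZE', 1)))
```

Since both new branches assign the same key, running with `--set_pytorch_max_split_size` replaces the conformer-specific `expandable_segments:True` value with `max_split_size_mb:256`.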