Skip to content

Commit 919f212

Browse files
committed
Don't always import tensorflow of pytorch in TMVA_CNN_Classification.py
Don't import tensorflow of pytorch for feature detection TMVA_CNN_Classification.py. The C++ version of the tutorial also doesn't do it, and importing TensorFlow can have bad consequences like symbol collisions with the system OpenBLAS.
1 parent 35f6c10 commit 919f212

File tree

1 file changed

+1
-24
lines changed

1 file changed

+1
-24
lines changed

tutorials/tmva/TMVA_CNN_Classification.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,13 @@
2727
import os
2828
import importlib.util
2929

30-
useKerasCNN = False
31-
32-
if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") == "yes":
33-
useKerasCNN = True
34-
3530
opt = [1, 1, 1, 1, 1]
3631
useTMVACNN = opt[0] if len(opt) > 0 else False
37-
useKerasCNN = opt[1] if len(opt) > 1 else useKerasCNN
32+
useKerasCNN = opt[1] if len(opt) > 1 else False
3833
useTMVADNN = opt[2] if len(opt) > 2 else False
3934
useTMVABDT = opt[3] if len(opt) > 3 else False
4035
usePyTorchCNN = opt[4] if len(opt) > 4 else False
4136

42-
if useKerasCNN:
43-
import tensorflow
44-
45-
# PyTorch has to be imported before ROOT to avoid crashes because of clashing
46-
# std::regexp symbols that are exported by cppyy.
47-
# See also: https://github.com/wlav/cppyy/issues/227
48-
torch_spec = importlib.util.find_spec("torch")
49-
if torch_spec is None:
50-
usePyTorchCNN = False
51-
print("TMVA_CNN_Classificaton","Skip using PyTorch since torch is not installed")
52-
else:
53-
import torch
54-
55-
56-
import ROOT
57-
58-
#switch off MT in OpenMP (BLAS)
59-
6037
TMVA = ROOT.TMVA
6138
TFile = ROOT.TFile
6239

0 commit comments

Comments
 (0)