@@ -81,6 +81,15 @@ def create_tf_data(self) -> tf.data.Dataset:
8181def get_model_train_details (cfg : Config , database : PilotDatabase ,
8282 model : str = None , model_type : str = None ) \
8383 -> Tuple [str , int , str , bool ]:
84+ """
85+ Returns automatic model name if none is given
86+ :param cfg: donkey config
87+ :param database: model database with existing training data
88+ :param model: model path
89+ :param model_type: type of model, like 'linear', 'tflite_linear', etc
90+ :return: tuple of model path, number, training type, and if
91+ tflite is requested
92+ """
8493 if not model_type :
8594 model_type = cfg .DEFAULT_MODEL_TYPE
8695 train_type = model_type
@@ -90,12 +99,13 @@ def get_model_train_details(cfg: Config, database: PilotDatabase,
9099 is_tflite = True
91100 model_num = 0
92101 if not model :
93- model_name , model_num = database .generate_model_name ()
102+ model_path , model_num = database .generate_model_name ()
94103 else :
95- model_name , model_ext = os .path .splitext (model )
104+ _ , model_ext = os .path .splitext (model )
105+ model_path = model
96106 is_tflite = model_ext == '.tflite'
97107
98- return model_name , model_num , train_type , is_tflite
108+ return model_path , model_num , train_type , is_tflite
99109
100110
101111def train (cfg : Config , tub_paths : str , model : str = None ,
@@ -105,10 +115,9 @@ def train(cfg: Config, tub_paths: str, model: str = None,
105115 Train the model
106116 """
107117 database = PilotDatabase (cfg )
108- model_name , model_num , train_type , is_tflite = \
118+ model_path , model_num , train_type , is_tflite = \
109119 get_model_train_details (cfg , database , model , model_type )
110120
111- output_path = os .path .join (cfg .MODELS_PATH , model_name + '.h5' )
112121 kl = get_model_by_type (train_type , cfg )
113122 if transfer :
114123 kl .load (transfer )
@@ -135,7 +144,7 @@ def train(cfg: Config, tub_paths: str, model: str = None,
135144 assert val_size > 0 , "Not enough validation data, decrease the batch " \
136145 "size or add more data."
137146
138- history = kl .train (model_path = output_path ,
147+ history = kl .train (model_path = model_path ,
139148 train_data = dataset_train ,
140149 train_steps = train_size ,
141150 batch_size = cfg .BATCH_SIZE ,
@@ -146,14 +155,14 @@ def train(cfg: Config, tub_paths: str, model: str = None,
146155 min_delta = cfg .MIN_DELTA ,
147156 patience = cfg .EARLY_STOP_PATIENCE ,
148157 show_plot = cfg .SHOW_PLOT )
149-
158+ base_path = os . path . splitext ( model_path )[ 0 ]
150159 if is_tflite :
151- tf_lite_model_path = f'{ os . path . splitext ( output_path )[ 0 ] } .tflite'
152- keras_model_to_tflite (output_path , tf_lite_model_path )
160+ tf_lite_model_path = f'{ base_path } .tflite'
161+ keras_model_to_tflite (model_path , tf_lite_model_path )
153162
154163 database_entry = {
155164 'Number' : model_num ,
156- 'Name' : model_name ,
165+ 'Name' : os . path . basename ( base_path ) ,
157166 'Type' : str (kl ),
158167 'Tubs' : tub_paths ,
159168 'Time' : time (),
0 commit comments