Skip to content

Commit 81c2ea3

Browse files
authored
* Bump version -> 4.2.1 (#881)
* Revert to passing full path to model in training call which got accidentally broken in 4.2 master.
1 parent 5c34d23 commit 81c2ea3

File tree

4 files changed

+23
-14
lines changed

4 files changed

+23
-14
lines changed

donkeycar/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
from pyfiglet import Figlet
33

4-
__version__ = '4.2.0'
4+
__version__ = '4.2.1'
55
f = Figlet(font='speed')
66

77
print(f.renderText('Donkey Car'))

donkeycar/pipeline/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def generate_model_name(self) -> Tuple[str, int]:
3333
else:
3434
this_num = 0
3535
date = time.strftime('%y-%m-%d')
36-
name = 'pilot_' + date + '_' + str(this_num)
37-
return name, this_num
36+
name = f'pilot_{date}_{this_num}.h5'
37+
return os.path.join(self.cfg.MODELS_PATH, name), this_num
3838

3939
def to_df(self) -> pd.DataFrame:
4040
if self.entries:

donkeycar/pipeline/training.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ def create_tf_data(self) -> tf.data.Dataset:
8181
def get_model_train_details(cfg: Config, database: PilotDatabase,
8282
model: str = None, model_type: str = None) \
8383
-> Tuple[str, int, str, bool]:
84+
"""
85+
Returns automatic model name if none is given
86+
:param cfg: donkey config
87+
:param database: model database with existing training data
88+
:param model: model path
89+
:param model_type: type of model, like 'linear', 'tflite_linear', etc
90+
:return: tuple of model path, number, training type, and if
91+
tflite is requested
92+
"""
8493
if not model_type:
8594
model_type = cfg.DEFAULT_MODEL_TYPE
8695
train_type = model_type
@@ -90,12 +99,13 @@ def get_model_train_details(cfg: Config, database: PilotDatabase,
9099
is_tflite = True
91100
model_num = 0
92101
if not model:
93-
model_name, model_num = database.generate_model_name()
102+
model_path, model_num = database.generate_model_name()
94103
else:
95-
model_name, model_ext = os.path.splitext(model)
104+
_, model_ext = os.path.splitext(model)
105+
model_path = model
96106
is_tflite = model_ext == '.tflite'
97107

98-
return model_name, model_num, train_type, is_tflite
108+
return model_path, model_num, train_type, is_tflite
99109

100110

101111
def train(cfg: Config, tub_paths: str, model: str = None,
@@ -105,10 +115,9 @@ def train(cfg: Config, tub_paths: str, model: str = None,
105115
Train the model
106116
"""
107117
database = PilotDatabase(cfg)
108-
model_name, model_num, train_type, is_tflite = \
118+
model_path, model_num, train_type, is_tflite = \
109119
get_model_train_details(cfg, database, model, model_type)
110120

111-
output_path = os.path.join(cfg.MODELS_PATH, model_name + '.h5')
112121
kl = get_model_by_type(train_type, cfg)
113122
if transfer:
114123
kl.load(transfer)
@@ -135,7 +144,7 @@ def train(cfg: Config, tub_paths: str, model: str = None,
135144
assert val_size > 0, "Not enough validation data, decrease the batch " \
136145
"size or add more data."
137146

138-
history = kl.train(model_path=output_path,
147+
history = kl.train(model_path=model_path,
139148
train_data=dataset_train,
140149
train_steps=train_size,
141150
batch_size=cfg.BATCH_SIZE,
@@ -146,14 +155,14 @@ def train(cfg: Config, tub_paths: str, model: str = None,
146155
min_delta=cfg.MIN_DELTA,
147156
patience=cfg.EARLY_STOP_PATIENCE,
148157
show_plot=cfg.SHOW_PLOT)
149-
158+
base_path = os.path.splitext(model_path)[0]
150159
if is_tflite:
151-
tf_lite_model_path = f'{os.path.splitext(output_path)[0]}.tflite'
152-
keras_model_to_tflite(output_path, tf_lite_model_path)
160+
tf_lite_model_path = f'{base_path}.tflite'
161+
keras_model_to_tflite(model_path, tf_lite_model_path)
153162

154163
database_entry = {
155164
'Number': model_num,
156-
'Name': model_name,
165+
'Name': os.path.basename(base_path),
157166
'Type': str(kl),
158167
'Tubs': tub_paths,
159168
'Time': time(),

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def package_files(directory, strip_leading):
2424
long_description = fh.read()
2525

2626
setup(name='donkeycar',
27-
version='4.2.0',
27+
version='4.2.1',
2828
long_description=long_description,
2929
description='Self driving library for python.',
3030
url='https://github.com/autorope/donkeycar',

0 commit comments

Comments
 (0)