import argparse

import tensorflow as tf

# Project-level imports; the star imports follow the repo's existing style.
from models.definitions.transformer import *
from utils.constants import *
from utils.optimizer import *
from utils.loss_metrics import *
from translation_script import *
from utils.export import *

def train_transformer(training_config):
    """
    Train the transformer model.

    :param training_config: dictionary containing the training configuration
    :return: None
    """
    # `tokenizers`, `train_batches` and `val_batches` are pulled in via the
    # star imports above (presumably from translation_script).
    transformer = Transformer(
        num_layers=NUM_LAYERS,
        d_model=MODEL_DIMENSION,
        num_heads=NUMBER_OF_HEADS,
        dff=DFF,
        input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),
        target_vocab_size=tokenizers.en.get_vocab_size().numpy(),
        dropout_rate=DROPOUT_PROB)
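
    # Note: the actual hyper-parameter values live in utils/constants.py. If they
    # mirror the base model from "Attention Is All You Need" they would be
    # NUM_LAYERS=6, MODEL_DIMENSION=512, NUMBER_OF_HEADS=8, DFF=2048 and
    # DROPOUT_PROB=0.1, but that is an assumption, not something this file defines.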

    # Optimizer settings (beta_1=0.9, beta_2=0.98, epsilon=1e-9) match the
    # original "Attention Is All You Need" paper.
    learning_rate = CustomSchedule(MODEL_DIMENSION)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                         epsilon=1e-9)
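
    # For reference, CustomSchedule (utils/optimizer.py) presumably implements the
    # warmup schedule from the paper, an assumption based on the matching Adam
    # hyper-parameters above:
    #
    #     lrate = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    #
    # i.e. the learning rate grows linearly for the first `warmup_steps` steps and
    # then decays proportionally to 1/sqrt(step).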

    transformer.compile(
        loss=masked_loss,
        optimizer=optimizer,
        metrics=[masked_accuracy])

    transformer.fit(train_batches,
                    epochs=training_config['num_of_epochs'],
                    validation_data=val_batches)

    # summary() is called after fit(): a subclassed Keras model is only built once
    # it has been called on data, and summary() prints its output directly
    # (wrapping it in print() would just print an extra `None`).
    transformer.summary()

    # Wrap the trained model and tokenizers into a directly servable artifact.
    translator = Translator(tokenizers, transformer)
    translator = ExportTranslator(translator)
    tf.saved_model.save(translator, export_dir='translator')
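
# A minimal sketch of how the exported SavedModel could be reloaded and queried
# later (the sample sentence is illustrative; the wrapped model translates
# Portuguese to English):
#
#     reloaded = tf.saved_model.load('translator')
#     print(reloaded('este é um problema que temos que resolver.').numpy())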

if __name__ == "__main__":
    #
    # Fixed args - don't change these unless you have a good reason.
    #
    # (CustomSchedule presumably defaults to the same 4000 warmup steps; the value
    # is stored in the config below for bookkeeping.)
    num_warmup_steps = 4000

    #
    # Modifiable args - feel free to play with these (only a small subset is
    # exposed, by design, to avoid clutter).
    #
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=20)
    # You should adjust this for your particular machine.
    parser.add_argument("--batch_size", type=int, help="target number of tokens in a src/trg batch", default=1500)
    # Logging/debugging/checkpoint related (helps a lot with experimentation).
    # Note: argparse's `type=bool` treats any non-empty string (including "False")
    # as True, so a real boolean flag is used instead (requires Python 3.9+).
    parser.add_argument("--enable_tensorboard", action=argparse.BooleanOptionalAction, help="enable tensorboard logging", default=True)
    parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq", default=10)
    parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq", default=1)
    args = parser.parse_args()

    # Wrap the training configuration into a dictionary. train_transformer itself
    # only reads num_of_epochs; the remaining entries are presumably consumed by
    # the surrounding logging/checkpoint utilities.
    training_config = dict()
    for arg in vars(args):
        training_config[arg] = getattr(args, arg)
    training_config['num_warmup_steps'] = num_warmup_steps

    # Train the original transformer model.
    train_transformer(training_config)
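
# Example invocation (values are illustrative):
#
#     python training_script.py --num_of_epochs 10 --batch_size 2000 --no-enable_tensorboard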