Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,21 @@ Classification performance compared to standard Keras LSTM for MNIST dataset:
:width: 100px
:alt: Loss [red: PLSTM, black: LSTM]
:align: center

Using the timegate with frozen weights
======================================

* Creating an initializer for the timegate:
.. code-block:: python

# Opening the gate every 8 timesteps
def timegate_init(shape, dtype=None):
return K.constant(np.vstack((
np.zeros(shape[1]) + 0.8, # period
np.zeros(shape[1]) + 0.01, # shift
np.zeros(shape[1]) + 0.05)), dtype=dtype) # ratio

* Setting the `timegate_initializer` and marking `trainable_timegate` as `False`:
.. code-block:: python

PhasedLSTM(150, return_sequences=True, timegate_initializer=timegate_init, trainable_timegate=False)
37 changes: 18 additions & 19 deletions phased_lstm_keras/PhasedLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from keras.engine import InputSpec
from keras.legacy import interfaces
from keras.layers import Recurrent
from keras.utils.generic_utils import get_custom_objects

def _time_distributed_dense(x, w, b=None, dropout=None,
input_dim=None, output_dim=None,
Expand Down Expand Up @@ -141,6 +140,7 @@ def __init__(self, units,
recurrent_constraint=None,
bias_constraint=None,
timegate_constraint='non_neg',
trainable_timegate=True,
dropout=0.,
recurrent_dropout=0.,
alpha=0.001,
Expand Down Expand Up @@ -171,6 +171,7 @@ def __init__(self, units,
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.alpha = alpha
self.trainable_timegate = trainable_timegate

def build(self, input_shape):
if isinstance(input_shape, list):
Expand All @@ -186,21 +187,20 @@ def build(self, input_shape):
if self.stateful:
self.reset_states()

self.kernel = self.add_weight((self.input_dim, self.units * 4),
self.kernel = self.add_weight(shape=(self.input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)

self.recurrent_kernel = self.add_weight(
(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)

if self.use_bias:
self.bias = self.add_weight((self.units * 4,),
self.bias = self.add_weight(shape=(self.units * 4,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
Expand Down Expand Up @@ -234,12 +234,13 @@ def build(self, input_shape):
self.bias_o = None

# time-gate
self.timegate_kernel = self.add_weight(
(3, self.units),
name='timegate_kernel',
initializer=self.timegate_initializer,
regularizer=self.timegate_regularizer,
constraint=self.timegate_constraint)
self.timegate_kernel = self.add_weight(shape=(3, self.units),
name='timegate_kernel',
initializer=self.timegate_initializer,
regularizer=self.timegate_regularizer,
constraint=self.timegate_constraint,
trainable=self.trainable_timegate)

self.built = True

def preprocess_input(self, inputs, training=None):
Expand Down Expand Up @@ -297,6 +298,7 @@ def dropped_inputs():
return constants

def step(self, inputs, states):

h_tm1 = states[0]
c_tm1 = states[1]
t_tm1 = states[2]
Expand All @@ -317,7 +319,6 @@ def step(self, inputs, states):
# a mod n = a - (n * int(a/n))
# phi = ((t - shift) % period) / period
phi = ((t - shift) - (period * ((t - shift) // period))) / period

# K.switch not consistent between Theano and Tensorflow backend, so write explicitly.
up = K.cast(K.less_equal(phi, r_on * 0.5), K.floatx()) * 2 * phi / r_on
mid = K.cast(K.less_equal(phi, r_on), K.floatx()) * \
Expand Down Expand Up @@ -392,10 +393,8 @@ def get_config(self):
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint),
'trainable_timegate': self.trainable_timegate,
'dropout': self.dropout,
'recurrent_dropout': self.recurrent_dropout}
base_config = super(PhasedLSTM, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'PhasedLSTM': PhasedLSTM})