September 14, 2020
import math

import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class SnapshotEnsemble(Callback):
    # this callback applies a cyclic learning-rate schedule (warm-up followed
    # by exponential decay, restarted for every snapshot), saves one snapshot
    # per cycle and allows loading them later as an ensemble
    __snapshot_name_fmt = "./snapshot_%d.hdf5"
    epoch_cycle = 0

    def __init__(self, n_models, n_epochs_per_model, lr_max, rampup_epoch, decay, verbose=1):
        """
        n_models -- number of models (snapshots) to train
        n_epochs_per_model -- number of epochs for each model (snapshot)
        lr_max -- maximum learning rate, held during warm-up
        rampup_epoch -- number of warm-up epochs at the start of each cycle
        decay -- exponential learning-rate decay factor applied after warm-up
        """
        self.n_epochs_per_model = n_epochs_per_model
        self.n_models = n_models
        self.n_epochs_total = self.n_models * self.n_epochs_per_model
        self.lr_max = lr_max
        self.rampup_epoch = rampup_epoch
        self.decay = decay
        self.verbose = verbose
        self.lrs = []
        self.history_temp = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': []}
        self.best_validation_accuracy = 0.0
    # compute the learning rate for the given (global) epoch: each cycle of
    # n_epochs_per_model epochs starts with rampup_epoch warm-up epochs held
    # at lr_max, then decays exponentially; despite its name, this method
    # implements a warm-up/exponential-decay schedule, not a cosine curve
    def cosine_annealing(self, epoch):
        # reset the per-cycle epoch counter at the start of every cycle
        if epoch % self.n_epochs_per_model == 0:
            self.epoch_cycle = 0
        self.epoch_cycle += 1
        if self.epoch_cycle <= self.rampup_epoch:
            return self.lr_max
        else:
            return self.lr_max * math.exp(-self.decay * self.epoch_cycle)
    # at the start of every epoch, apply the scheduled learning rate
    def on_epoch_begin(self, epoch, logs=None):
        lr = self.cosine_annealing(epoch)
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        # keep a record of the learning rates that were used
        self.lrs.append(lr)
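    # The snippet above omits the snapshot-saving half of the callback; the
    # sketch below is an assumed reconstruction based on the attributes
    # __snapshot_name_fmt, history_temp and best_validation_accuracy defined
    # in __init__ (the 'acc'/'val_acc' metric keys assume the model was
    # compiled with metrics named accordingly).
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # accumulate this epoch's metrics for the current snapshot
        for metric in self.history_temp:
            if metric in logs:
                self.history_temp[metric].append(logs[metric])
        # track the best validation accuracy seen so far
        if logs.get('val_acc', 0.0) > self.best_validation_accuracy:
            self.best_validation_accuracy = logs['val_acc']
        # at the end of every cycle, save the current weights as a snapshot
        if (epoch + 1) % self.n_epochs_per_model == 0:
            filename = self.__snapshot_name_fmt % ((epoch + 1) // self.n_epochs_per_model)
            if self.verbose:
                print('saving snapshot to %s' % filename)
            self.model.save_weights(filename)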
snapshot_ensemble = SnapshotEnsemble(n_models=3,
                                     n_epochs_per_model=20,
                                     lr_max=1e-2,  # high learning rate held during warm-up
                                     rampup_epoch=5,
                                     decay=0.1)
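# A quick preview of the schedule these settings produce (a sketch on a
# throwaway instance, since cosine_annealing mutates the per-cycle counter):
preview = SnapshotEnsemble(n_models=3, n_epochs_per_model=20,
                           lr_max=1e-2, rampup_epoch=5, decay=0.1)
schedule = [preview.cosine_annealing(e) for e in range(preview.n_epochs_total)]
print([round(lr, 5) for lr in schedule[:8]])
# -> 0.01 for the five warm-up epochs, then ~0.00549, ~0.00497, ~0.00449, ...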
# ds_train / ds_valid, train_data / valid_data and BATCH_SIZE are assumed to
# come from the data pipeline set up earlier
history = model.fit(ds_train, validation_data=ds_valid,
                    epochs=snapshot_ensemble.n_epochs_total,
                    steps_per_epoch=len(train_data) // BATCH_SIZE,
                    validation_steps=len(valid_data) // BATCH_SIZE,
                    verbose=1,
                    callbacks=[snapshot_ensemble])
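After training, the saved snapshots can be combined at inference time. A minimal sketch, assuming a build_model() function (hypothetical here) that recreates the same architecture the snapshots were trained with, and the snapshot naming scheme used by the callback above:

import numpy as np

def load_snapshot_ensemble(n_models, build_model):
    # recreate the architecture and load each snapshot's saved weights
    models = []
    for i in range(1, n_models + 1):
        m = build_model()  # build_model is a hypothetical model factory
        m.load_weights("./snapshot_%d.hdf5" % i)
        models.append(m)
    return models

def ensemble_predict(models, x):
    # average the predicted probabilities over all snapshots
    return np.mean([m.predict(x) for m in models], axis=0)

# usage (x_test is a hypothetical array of inputs):
# preds = ensemble_predict(load_snapshot_ensemble(3, build_model), x_test)

Averaging the snapshot predictions is what makes this an ensemble: each cycle ends in a different local minimum, and combining them typically beats any single snapshot.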