TensorFlow Snapshot Ensembling

Snapshot Ensemble implementation

class SnapshotEnsemble(Callback):
    """Keras callback that applies a cyclic learning-rate schedule.

    Training is split into ``n_models`` cycles of ``n_epochs_per_model``
    epochs each. At the beginning of every epoch the optimizer's learning
    rate is overwritten: it is held at ``lr_max`` for the first
    ``rampup_epoch`` epochs of a cycle and afterwards set to
    ``lr_max * exp(-decay * k)``, where ``k`` is the 1-based position of the
    epoch inside its cycle. Every applied rate is appended to ``self.lrs``.

    NOTE: the method is called ``cosine_annealing`` for backward
    compatibility, but the schedule it computes is hold-then-exponential
    decay, not a cosine curve. This block only implements the learning-rate
    schedule; no snapshot saving or loading logic is present here, although
    a file-name template for snapshots is defined.
    """

    # File-name template for snapshots (one ``%d`` placeholder). Not used by
    # any method in this block.
    __snapshot_name_fmt = "./snapshot_%d.hdf5"
    # Class-level default for the 1-based epoch position inside the current
    # cycle; each instance overwrites it in ``__init__``.
    epoch_cycle = 0

    def __init__(self, n_models, n_epochs_per_model, lr_max, rampup_epoch, decay, verbose=1):
        """Store the schedule parameters.

        n_models -- quantity of models (snapshots), i.e. number of cycles
        n_epochs_per_model -- quantity of epochs for every model (snapshot)
        lr_max -- learning rate used during the first ``rampup_epoch``
            epochs of each cycle and as the base of the decay afterwards
        rampup_epoch -- number of epochs at the start of each cycle during
            which the learning rate stays at ``lr_max``
        decay -- exponential decay coefficient applied after the ramp-up
        verbose -- stored on the instance; not read by this block

        Raises ValueError if ``n_epochs_per_model`` is smaller than 1, since
        the cycle position is computed modulo that value.
        """
        # Initialise the Keras Callback base class before adding our own state.
        super().__init__()

        if n_epochs_per_model < 1:
            raise ValueError(
                "n_epochs_per_model must be >= 1, got %r" % (n_epochs_per_model,)
            )

        self.n_epochs_per_model = n_epochs_per_model
        self.n_models = n_models
        # Total epoch count to pass to ``model.fit(epochs=...)``.
        self.n_epochs_total = self.n_models * self.n_epochs_per_model
        self.lr_max = lr_max
        self.rampup_epoch = rampup_epoch
        self.decay = decay
        self.verbose = verbose
        # Per-instance cycle position (shadows the class-level default).
        self.epoch_cycle = 0
        # Learning rate applied at each epoch, in call order.
        self.lrs = []
        # Initialised here; not read or updated by the methods in this block.
        self.history_temp = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': []}
        self.best_validation_accuracy = 0.0

    def cosine_annealing(self, epoch):
        """Return the learning rate for the 0-based global ``epoch``.

        The result depends only on ``epoch`` and the configured parameters,
        so it is also correct when epochs are not visited sequentially from
        zero (for example when training starts at a later epoch).
        ``self.epoch_cycle`` is updated to the 1-based position of ``epoch``
        inside its cycle as a side effect.
        """
        # For sequential calls starting at a cycle boundary this equals the
        # value the previous counter-based implementation produced.
        self.epoch_cycle = epoch % self.n_epochs_per_model + 1

        if self.epoch_cycle <= self.rampup_epoch:
            return self.lr_max
        return self.lr_max * math.exp(-self.decay * self.epoch_cycle)

    def on_epoch_begin(self, epoch, logs=None):
        """Set the optimizer's learning rate for ``epoch`` and record it.

        ``logs`` is accepted for compatibility with the callback interface
        and is not used.
        """
        lr = self.cosine_annealing(epoch)
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        # Keep a trace of the schedule that was actually applied.
        self.lrs.append(lr)

Usage

# Configure the schedule: 3 cycles (snapshots) of 20 epochs each.
snapshot_ensemble = SnapshotEnsemble(
    n_models=3,
    n_epochs_per_model=20,
    lr_max=1e-2,  # high learning rate held during the ramp-up epochs
    rampup_epoch=5,
    decay=0.1,
)

# Train for the total number of epochs implied by the schedule, letting the
# callback overwrite the learning rate at the start of every epoch.
history = model.fit(
    ds_train,
    validation_data=ds_valid,
    epochs=snapshot_ensemble.n_epochs_total,
    steps_per_epoch=int(len(train_data) / BATCH_SIZE),
    validation_steps=int(len(valid_data) / BATCH_SIZE),
    verbose=1,
    callbacks=[snapshot_ensemble],
)

Written by @Ryan Liwag
Data scientist who dabbles in machine learning and software engineering. If you're interested in working with me, drop me an email at rjhontomin@gmail.com