Early Stopping implementation

In machine learning, early stopping is a form of regularization used to avoid overfitting when training a learner with an iterative method, such as gradient descent. Such methods update the learner so as to make it better fit the training data with each iteration. Up to a point, this improves the learner's performance on data outside of the training set. Past that point, however, improving the learner's fit to the training data comes at the expense of increased generalization error. Early stopping rules provide guidance as to how many iterations can be run before the learner begins to over-fit. Early stopping rules have been employed in many different machine learning methods, with varying amounts of theoretical foundation.

Validation-based early stopping

These early stopping rules work by splitting the original training set into a new training set and a validation set. The error on the validation set is used as a proxy for the generalization error in determining when overfitting has begun. These methods are most commonly employed in the training of neural networks. Prechelt gives the following summary of a naive implementation of holdout-based early stopping:[8]

  1. Split the training data into a training set and a validation set, e.g. in a 2-to-1 proportion.
  2. Train only on the training set and evaluate the per-example error on the validation set once in a while, e.g. after every fifth epoch.
  3. Stop training as soon as the error on the validation set is higher than it was the last time it was checked.
  4. Use the weights the network had in that previous step as the result of the training run.
    — Lutz Prechelt, Early Stopping – But When?
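
A minimal sketch of this naive procedure in Java (trainOneEpoch, validationError, copyWeights, and restoreWeights are hypothetical placeholders for whatever your framework provides):

double bestValError = Double.MAX_VALUE;
Object bestWeights = copyWeights(network);                        // hypothetical: snapshot of the initial weights
for (int epoch = 1; epoch <= maxEpochs; epoch++) {
    trainOneEpoch(network, trainingSet);                          // hypothetical: one full pass over the training set
    if (epoch % 5 == 0) {                                         // evaluate "once in a while", e.g. every fifth epoch
        double valError = validationError(network, validationSet); // hypothetical: per-example error on the validation set
        if (valError > bestValError) {
            restoreWeights(network, bestWeights);                 // step 4: revert to the previous best weights
            break;                                                // step 3: validation error went up, so stop
        }
        bestValError = valError;
        bestWeights = copyWeights(network);                       // remember the current best weights
    }
}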

More sophisticated forms use cross-validation – multiple partitions of the data into training set and validation set – instead of a single partition into a training set and validation set. Even this simple procedure is complicated in practice by the fact that the validation error may fluctuate during training, producing multiple local minima. This complication has led to the creation of many ad-hoc rules for deciding when overfitting has truly begun.[8]
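
One widely used ad-hoc rule is "patience": rather than stopping at the first uptick in validation error, training stops only after the error has failed to improve for several consecutive checks. A sketch, reusing the hypothetical helpers from the snippet above:

int patience = 5;                          // tolerated number of consecutive non-improving checks
int badChecks = 0;
double bestValError = Double.MAX_VALUE;
Object bestWeights = copyWeights(network);
for (int epoch = 1; epoch <= maxEpochs; epoch++) {
    trainOneEpoch(network, trainingSet);
    double valError = validationError(network, validationSet);
    if (valError < bestValError) {
        bestValError = valError;
        bestWeights = copyWeights(network);
        badChecks = 0;                     // improvement: reset the patience counter
    } else if (++badChecks >= patience) {
        break;                             // no improvement for 'patience' checks: stop
    }
}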


Paper: Early Stopping – But When? http://page.mi.fu-berlin.de/prechelt/Biblio/stop_tricks1997.pdf


Early Stopping with the Deeplearning4j (DL4J) library

When training neural networks, numerous decisions need to be made regarding the settings (hyperparameters) used, in order to obtain good performance. One such hyperparameter is the number of training epochs: that is, how many full passes of the data set (epochs) should be used? If we use too few epochs, we might underfit (i.e., not learn everything we can from the training data); if we use too many epochs, we might overfit (i.e., fit the 'noise' in the training data, and not the signal).

Early stopping attempts to remove the need to manually set this value. It can also be considered a type of regularization method (like L1/L2 weight decay and dropout) in that it can stop the network from overfitting.

The idea behind early stopping is relatively simple:

  • Split data into training and test sets
  • At the end of each epoch (or, every N epochs):
    • evaluate the network performance on the test set
    • if the network outperforms the previous best model: save a copy of the network at the current epoch
  • Take as our final model the model that has the best test set performance

This is shown graphically below:

[Figure: Early Stopping]

The best model is the one saved at the time of the vertical dotted line - i.e., the model with the best accuracy on the test set.

Using DL4J’s early stopping functionality requires you to provide a number of configuration options:

  • A score calculator, such as the DataSetLossCalculator (JavaDoc, Source Code) for a MultiLayerNetwork, or DataSetLossCalculatorCG (JavaDoc, Source Code) for a ComputationGraph. This is used to calculate the score at each evaluation (for example: the loss function value on a test set, or the accuracy on the test set)
  • How frequently we want to calculate the score function (default: every epoch)
  • One or more termination conditions, which tell the training process when to stop. There are two classes of termination conditions:
    • Epoch termination conditions: evaluated every N epochs
    • Iteration termination conditions: evaluated once per minibatch
  • A model saver that defines how models are saved (see: LocalFileModelSaver (JavaDoc, Source Code) and InMemoryModelSaver (JavaDoc, Source Code))

An example with an epoch termination condition of a maximum of 30 epochs, a maximum of 20 minutes of training time, score calculation every epoch, and saving of the intermediate results to disk:

MultiLayerConfiguration myNetworkConfiguration = ...;
DataSetIterator myTrainData = ...;
DataSetIterator myTestData = ...;

EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
        .epochTerminationConditions(new MaxEpochsTerminationCondition(30))
        .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
        .scoreCalculator(new DataSetLossCalculator(myTestData, true))
        .evaluateEveryNEpochs(1)
        .modelSaver(new LocalFileModelSaver(directory))
        .build();

EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, myNetworkConfiguration, myTrainData);

//Conduct early stopping training:
EarlyStoppingResult result = trainer.fit();

//Print out the results:
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());

//Get the best model:
MultiLayerNetwork bestModel = result.getBestModel();

Examples of epoch termination conditions:

  • To terminate training after a maximum number of epochs, use MaxEpochsTerminationCondition (as in the example above)
  • To terminate training if the score has not improved over some number of consecutive epochs, use ScoreImprovementEpochTerminationCondition

Examples of iteration termination conditions:

  • To terminate training after a specified amount of time (without waiting for an epoch to complete), use MaxTimeIterationTerminationCondition
  • To terminate training if the score exceeds a certain value at any point, use MaxScoreIterationTerminationCondition. This can be useful for example to terminate the training immediately if the network is poorly tuned or training becomes unstable (such as exploding weights/scores).
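
Several termination conditions can be combined in a single configuration. A hedged sketch (ScoreImprovementEpochTerminationCondition is assumed to come from the same DL4J termination package; check the directory linked below for the exact classes available in your version):

EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
        .epochTerminationConditions(
                new MaxEpochsTerminationCondition(100),                  // hard cap on the number of epochs
                new ScoreImprovementEpochTerminationCondition(5))        // stop if the score hasn't improved in 5 epochs
        .iterationTerminationConditions(
                new MaxScoreIterationTerminationCondition(10.0))         // abort immediately if the loss explodes
        .scoreCalculator(new DataSetLossCalculator(myTestData, true))
        .modelSaver(new LocalFileModelSaver(directory))
        .build();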

The source code for the built-in termination conditions is in this directory.

You can of course implement your own iteration and epoch termination conditions.
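
For example, a custom epoch termination condition might look like the following sketch. TargetScoreEpochTerminationCondition is a hypothetical class, and the exact EpochTerminationCondition method signatures vary between DL4J versions (newer versions pass an additional minimize flag to terminate), so treat this as an outline rather than a drop-in implementation:

import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;

// Hypothetical condition: stop once the score drops below a target value.
public class TargetScoreEpochTerminationCondition implements EpochTerminationCondition {
    private final double targetScore;

    public TargetScoreEpochTerminationCondition(double targetScore) {
        this.targetScore = targetScore;
    }

    @Override
    public void initialize() {
        // no internal state to reset for this condition
    }

    @Override
    public boolean terminate(int epochNum, double score) {
        return score <= targetScore;   // stop as soon as the loss reaches the target value
    }
}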

Early Stopping w/ Parallel Wrapper

The early stopping implementation described above will only work with a single device. However, EarlyStoppingParallelTrainer provides the same early stopping functionality and allows you to optimize for multiple CPUs or GPUs. EarlyStoppingParallelTrainer wraps your model in a ParallelWrapper class and performs localized distributed training.

Note that EarlyStoppingParallelTrainer doesn't support all of the functionality of its single-device counterpart. It is not UI-compatible and may not work with complex iteration listeners. This is due to how the model is distributed and copied in the background.

TestParallelEarlyStopping.java gives a good example of setting up parallel early stopping in different scenarios.
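
A hedged sketch of what that setup might look like (the constructor arguments for worker count, prefetch buffer size, and averaging frequency are assumptions based on the DL4J test code; consult TestParallelEarlyStopping.java for the exact signatures in your version):

MultiLayerNetwork myNetwork = new MultiLayerNetwork(myNetworkConfiguration);  // model built from the earlier configuration
myNetwork.init();

EarlyStoppingParallelTrainer<MultiLayerNetwork> parallelTrainer =
        new EarlyStoppingParallelTrainer<>(esConf, myNetwork, myTrainData, null,
                2,    // number of parallel workers (assumption)
                2,    // prefetch buffer size (assumption)
                1);   // averaging frequency (assumption)
EarlyStoppingResult<MultiLayerNetwork> result = parallelTrainer.fit();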
