В настоящее время я использую следующий код:
callbacks = [
EarlyStopping(monitor='val_loss', patience=2, verbose=0),
ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
callbacks=callbacks)
Он говорит Керасу прекратить тренировку, если потери не улучшились в течение 2 эпох. Но я хочу прекратить тренировку после того, как потеря стала меньше некоторого постоянного «THR»:
if val_loss < THR:
break
Я видел в документации, что есть возможность сделать свой обратный вызов: http://keras.io/callbacks/ Но ничего не нашел, как остановить процесс обучения. Мне нужен совет.
from keras.callbacks import Callback
Обратный вызов keras.callbacks.EarlyStopping имеет аргумент min_delta. Из документации Keras:
источник
min_delta
сохранялось в течение нескольких эпох?Одним из решений является вызов
model.fit(nb_epoch=1, ...)
внутри цикла for, затем вы можете поместить оператор break внутри цикла for и выполнить любой другой настраиваемый поток управления, который хотите.источник
Я решил ту же проблему, используя собственный обратный вызов.
В следующем пользовательском коде обратного вызова присвойте THR значение, при котором вы хотите остановить обучение, и добавьте обратный вызов в свою модель.
from keras.callbacks import Callback class stopAtLossValue(Callback): def on_batch_end(self, batch, logs={}): THR = 0.03 #Assign THR with the value at which you want to stop training. if logs.get('loss') <= THR: self.model.stop_training = True
источник
Пока я практиковал специализацию TensorFlow , я изучил очень элегантную технику. Просто немного изменен из принятого ответа.
Давайте рассмотрим пример с нашими любимыми данными MNIST.
import tensorflow as tf class new_callback(tf.keras.callbacks.Callback): def epoch_end(self, epoch, logs={}): if(logs.get('accuracy')> 0.90): # select the accuracy print("\n !!! 90% accuracy, no further training !!!") self.model.stop_training = True mnist = tf.keras.datasets.mnist (x_train, y_train),(x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize callbacks = new_callback() # model = tf.keras.models.Sequential([# define your model here]) model.compile(optimizer=tf.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
Итак, здесь я установил
metrics=['accuracy']
, и, таким образом, в классе обратного вызова условие установлено на'accuracy'> 0.90
.Вы можете выбрать любую метрику и следить за тренировкой, как в этом примере. Самое главное, что вы можете установить разные условия для разных метрик и использовать их одновременно.
Надеюсь, это поможет!
источник
Для меня модель остановит обучение только в том случае, если я добавлю оператор return после установки для параметра stop_training значения True, потому что я звонил после self.model.evaluate. Поэтому либо убедитесь, что в конце функции поставили stop_training = True, либо добавьте оператор возврата.
def on_epoch_end(self, batch, logs): self.epoch += 1 self.stoppingCounter += 1 print('\nstopping counter \n',self.stoppingCounter) #Stop training if there hasn't been any improvement in 'Patience' epochs if self.stoppingCounter >= self.patience: self.model.stop_training = True return # Test on additional set if there is one if self.testingOnAdditionalSet: evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0) self.validationLoss2.append(evaluation[0]) self.validationAcc2.append(evaluation[1])enter code here
источник
Если вы используете настраиваемый цикл обучения, вы можете использовать a
collections.deque
, который представляет собой «скользящий» список, который можно добавлять, и левые элементы выскакивают, когда список длиннее, чемmaxlen
. Вот строчка:loss_history = deque(maxlen=early_stopping + 1) for epoch in range(epochs): fit(epoch) loss_history.append(test_loss.result().numpy()) if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history) break
Вот полный пример:
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow_datasets as tfds import tensorflow as tf from tensorflow.keras.layers import Dense from collections import deque data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True) data = data.map(lambda x, y: (tf.cast(x, tf.int32), y)) train_dataset = data.take(120).batch(4) test_dataset = data.skip(120).take(30).batch(4) model = tf.keras.models.Sequential([ Dense(8, activation='relu'), Dense(16, activation='relu'), Dense(info.features['label'].num_classes)]) loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) train_loss = tf.keras.metrics.Mean() test_loss = tf.keras.metrics.Mean() train_acc = tf.keras.metrics.SparseCategoricalAccuracy() test_acc = tf.keras.metrics.SparseCategoricalAccuracy() opt = tf.keras.optimizers.Adam(learning_rate=1e-3) @tf.function def train_step(inputs, labels): with tf.GradientTape() as tape: logits = model(inputs, training=True) loss = loss_object(labels, logits) gradients = tape.gradient(loss, model.trainable_variables) opt.apply_gradients(zip(gradients, model.trainable_variables)) train_loss(loss) train_acc(labels, logits) @tf.function def test_step(inputs, labels): logits = model(inputs, training=False) loss = loss_object(labels, logits) test_loss(loss) test_acc(labels, logits) def fit(epoch): template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \ 'Train Acc {:.2f} Test Acc {:.2f}' train_loss.reset_states() test_loss.reset_states() train_acc.reset_states() test_acc.reset_states() for X_train, y_train in train_dataset: train_step(X_train, y_train) for X_test, y_test in test_dataset: test_step(X_test, y_test) print(template.format( epoch + 1, train_loss.result(), test_loss.result(), train_acc.result(), test_acc.result() )) def main(epochs=50, early_stopping=10): loss_history = deque(maxlen=early_stopping + 1) for epoch in range(epochs): fit(epoch) loss_history.append(test_loss.result().numpy()) if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history): print(f'\nEarly stopping. No validation loss ' f'improvement in {early_stopping} epochs.') break if __name__ == '__main__': main(epochs=250, early_stopping=10)
Epoch 1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33 Epoch 2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33 Epoch 3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33 Epoch 4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33 Epoch 5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33 ... Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87 Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87 Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87 Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87 Early stopping. No validation loss improvement in 10 epochs.
источник