Я тренирую нейронную сеть для своего проекта с помощью Keras. Керас предусмотрел функцию ранней остановки. Могу ли я узнать, какие параметры следует соблюдать, чтобы моя нейронная сеть не переобучалась с помощью ранней остановки?
источник
Я тренирую нейронную сеть для своего проекта с помощью Keras. Керас предусмотрел функцию ранней остановки. Могу ли я узнать, какие параметры следует соблюдать, чтобы моя нейронная сеть не переобучалась с помощью ранней остановки?
Ранняя остановка - это, по сути, прекращение тренировки, когда ваши потери начинают расти (или, другими словами, точность проверки начинает снижаться). Согласно документам он используется следующим образом;
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0, mode='auto')
Значения зависят от вашей реализации (проблема, размер партии и т. Д.), Но обычно для предотвращения переобучения я бы использовал;
monitor
аргумента значение 'val_loss'
.min_delta
- это порог, определяющий, можно ли количественно оценить потерю в определенную эпоху как улучшение или нет. Если разница в убытках ниже min_delta
, это определяется как отсутствие улучшения. Лучше оставить его равным 0, так как нас интересует, когда убыток станет хуже.patience
Аргумент представляет количество эпох перед остановкой после того, как ваши убытки начнут расти (перестают улучшаться). Это зависит от вашей реализации, если вы используете очень маленькие партии
или большую скорость обучения, ваш зигзагообразный зигзаг потерь (точность будет более шумной), поэтому лучше установить большой patience
аргумент. Если вы используете большие партии и небольшую скорость обучения, ваши потери будут более плавными, поэтому вы можете использовать меньший patience
аргумент. В любом случае я оставлю это как 2, чтобы дать модели больше шансов.verbose
решает, что печатать, оставьте значение по умолчанию (0).mode
Аргумент зависит от того, в каком направлении находится ваше отслеживаемое количество (должно ли оно уменьшаться или увеличиваться), поскольку мы отслеживаем потери, которые мы можем использовать min
. Но давайте оставим keras обрабатывать это для нас и установим это наauto
Поэтому я бы использовал что-то подобное и поэкспериментировал, построив график потери ошибок с ранней остановкой и без нее.
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=2,
verbose=0, mode='auto')
Для возможной двусмысленности в том, как работают обратные вызовы, я постараюсь объяснить больше. Как только вы вызываете fit(... callbacks=[es])
свою модель, Keras вызывает предопределенные функции заданных объектов обратного вызова. Эти функции могут быть вызваны on_train_begin
, on_train_end
, on_epoch_begin
, on_epoch_end
и on_batch_begin
, on_batch_end
. Обратный вызов ранней остановки вызывается в конце каждой эпохи, сравнивает лучшее отслеживаемое значение с текущим и останавливается, если выполняются условия (сколько эпох прошло с момента наблюдения лучшего отслеживаемого значения, и является ли это более чем аргументом терпения, разница между последнее значение больше min_delta и т. д.).
Как указано в комментариях @BrentFaust, обучение модели будет продолжаться до тех пор, пока не будут выполнены условия ранней остановки или epochs
параметр (по умолчанию = 10) fit()
. Установка обратного вызова Early Stopping не заставит модель тренироваться сверх своего epochs
параметра. Таким образом, вызов fit()
функции с большим epochs
значением выиграет от обратного вызова Early Stopping.
min_delta
- это порог, определяющий, следует ли количественно оценивать изменение отслеживаемого значения как улучшение или нет. Так что да, если мы дадим,monitor = 'val_loss'
то это будет относиться к разнице между текущей потерей валидации и предыдущей потерей валидации. На практике, если вы дадитеmin_delta=0.1
уменьшение потерь при проверке (текущее - предыдущее) менее 0,1, это не будет количественно определено, поэтому обучение будет остановлено (если оно естьpatience = 0
).callbacks=[EarlyStopping(patience=2)]
это не имеет эффекта, если не указаны эпохиmodel.fit(..., epochs=max_epochs)
.epoch=1
в цикле for (для различных случаев использования), в котором этот обратный вызов не удастся. Если в моем ответе есть двусмысленность, я постараюсь выразить его лучше.restore_best_weights
аргумент (пока не указан в документации), который загружает модель с лучшими весами после обучения. Но для ваших целей я бы использовалModelCheckpoint
обратный вызов сsave_best_only
аргументом. Вы можете проверить документацию, она проста в использовании, но вам нужно вручную загрузить лучшие веса после тренировки.