DL(Deep-Learning)/Tensorflow

[Keras] callback 함수 - EarlyStopping

AI 그게 뭔데 2022. 1. 25. 15:08

Tensorflow Keras의 EarlyStopping 콜백 함수를 활용하면, model의 성능 지표(acc, loss등)가 설정한 epoch동안 개선되지 않을 때 조기 종료할 수 있다.

 

아래의 공식 페이지에서 사용방법을 확인할 수 있다.

 

tf.keras.callbacks.EarlyStopping  |  TensorFlow Core v2.7.0

Stop training when a monitored metric has stopped improving.

www.tensorflow.org

 

 

문법

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0,
    mode='auto', baseline=None, restore_best_weights=False
)

 

 

✔ 인자 설명

인자 설명
monitor EarlyStopping 기준이 되는 값을 입력한다.

만약 'val_loss'를 입력하면 val_loss가 더이상 감소되지 않을 경우 EarlyStopping을 적용한다.
min_delta 개선된 것으로 간주하기 위한 최소한의 변화량

예를 들어, min_delta 0.01이고, 30에폭에 정확도가 0.888라고 할 때,
31
에폭에 정확도가 0.8889라고 하면 이는 0.001의 개선이 있었지만 min_delta 0.01에는 미치지 못했으므로 개선된 것으로 보지 않는다.
patience Training이 진행됨에도 더이상 monitor되는 값의 개선이 없을 경우, 최적의 monitor 값을 기준으로 몇 번의 epoch을 진행할 지 정하는 값

예를 들어 patience3이고, 30에폭에 정확도가 90%였을 때,
만약 31번째에 정확도 89%, 32번째에 89%, 33번째에 88%라면 더 이상 Training을 진행하지 않고 종료한다.
verbose 0 또는 1

1
일 경우, EarlyStopping이 적용될 때, 화면에 보여진다.
0
일 경우, 화면에 나타냄 없이 종료한다.
mode "auto" 또는 "min" 또는 "max"

monitor
되는 값이 최소가 되어야 하는지, 최대가 되어야 하는지 알려주는 인자
예를 들어, monitor하는 값이 val_acc 즉 정확도일 경우, 값이 클수록 좋기때문에 "max"를 입력하고, val_loss일 경우 작을수록 좋기 때문에 "min"을 입력한다.
"auto"
는 모델이 알아서 판단해준다.
baseline 모델이 달성해야하는 최소한의 기준값을 선정
patience
이내에 모델이 baseline보다 개선됨이 보이지 않으면 Training을 중단시킨다.

예를 들어 patience 3이고 baseline이 정확도기준 0.98 이라면,
3
번의 trianing안에 0.98의 정확도를 달성하지 못하면 Training이 종료된다.
restore_best_weights True, False

True
라면 training이 끝난 후, model weight monitor하고 있던 값이 가장 좋았을 때의 weight로 복원한다.
False
라면, 마지막 training이 끝난 후의 weight 그대로를 나타낸다.

 

 

출처: https://deep-deep-deep.tistory.com/55 [딥딥딥]

 

 

❗ Early stopping 사용시 주의사항

early stopping은 훈련을 언제 종료시킬지를 결정할 뿐이고, Best 성능을 갖는 모델을 저장하지는 않는다. 따라서 early stopping과 함께 모델을 저장하는 callback 함수를 반드시 활용해야 한다.

 

callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss',
                                           patience=5),
             keras.callbacks.ModelCheckpoint(filepath='best_model.h5',
                                             monitor='val_loss',
                                             save_best_only=True)]
              
              
keras_model_best = keras.models.load_model('best_model.h5')

 

ModelCheckpoint는 3개의 인자를 입력받는다.

filepath는 매 epoch마다 훈련된 모델을 저장할 경로를 입력한다.

monitor는 성능지표를 무엇으로 볼 것 인지를 나타낸다.

save_best_only는 True일 경우, 성능지표 기준으로 가장 좋은 성능을 보이는 모델을 저장한다. False인 경우, 매 에폭마다 모델이 filepath{epoch}으로 저장된다.

훈련이 종료된 후에 훈련된 모델을 곧바로 사용하지 않고, 저장된 best model을 불러와서 사용한다면 위에서 지적했던 early stopping 단독으로 인한 문제를 극복할 수 있다.