실습/keras

loss weight 추가하기 및 학습 도중 loss weight 바꾸기

gldmg 2018. 1. 13. 12:01

import keras.backend as K


w1 = K.variable(1.2)

w2 = K.variable(1.6)


model.compile( .., loss_weights=[w1, w2], ..) # loss별 weight 부여


class Dynamic_loss_weights(Callback): # 콜백클래스 상속

    def __init__(self, w1, w2):

        self.w1 = w1

        self.w2 = w2

    def on_epoch_end(self, epoch, log={}): # epoch 끝날 때마다 호출됨

        K.set_value(self.w1, K.get_value(self.w1) + 0.1)

        K.set_value(self.w2, K.get_value(self.w2) - 0.1)


new_callback = Dynamic_loss_weights(w1, w2)


model.fit(..., callbacks = new_callback, ...) # 여러 콜백클래스를 사용하면, 콜백리스트에 append 시켜주면 됨

# model.fit_generator 도 같은 방법으로 추가