tf.train.Saver 객체를 통해서 모델을 저장하고 불러올 수 있다.
1. 모델 저장
# saver 객체 생성. # max_to_keep를 안쓰면, 최대 5개만 저장된다. (이전에 저장된 checkpoint가 지워짐) saver = tf.train.Saver(max_to_keep=10) # 세션에 있는 variable들을 저장. # 총 3가지 파일이 한번에 저장된다. 파일명.data-00000-of-00001, .index, .meta # 아직 세개가 뭘 의미하는 지는 잘 모르지만, meta는 모델 구조를 명시한다. # global_step을 써주면, baseline_model-100.meta 처럼 저장되어 유용함. saver.save(sess, "./checkpoints/baseline_model", global_step=step)
2. 모델 불러오기
두 가지 방법이 있다.
2.1 먼저 코드로 학습된 모델을 작성하고(train에 사용한 네트워크 코드를 그대로 사용), 해당하는 웨이트들에 덮어씌우는 방식 (웨이트 범위를 설정해서 특정 부분만 load할 수도 있다.)
2.2 meta 파일을 이용하여, 모델 코드 없이 바로 불러서 사용하기 (keras와 매우 유사)
일단 두 번째 방법만 정리하였다.
하나하나 설명하는 것보다 예시 코드를 보여주는게 나을거같아서 통으로 올림
import numpy as np import tensorflow as tf import os import dataset batch_size = 128 model_name = 'baseline_model' step_list = [600, 700, 800, 900] model_names = [] for step in step_list: model_names.append(model_name + '-' + str(step)) print('dataset loading... ') DB = dataset.SVPC_Dataset('./dataset/test.csv', 'test', divide_value=10000000, round=4, normaliza=False) print('dataset loaded !! ') for idx, model_name in enumerate(model_names): writer = dataset.Result_Writer(fileName=model_name) saver = tf.train.import_meta_graph(os.path.join('./checkpoints',model_name + '.meta')) print('{}: predicting...{}/{}'.format(model_names, idx, len(model_names))) with tf.Session() as sess: saver.restore(sess, os.path.join('./checkpoints',model_name)) graph = tf.get_default_graph() x = graph.get_tensor_by_name("input:0") y_pred = graph.get_tensor_by_name("output:0") for idx, (test_x, id) in enumerate(DB.generator(batch_size)): predicted_y = sess.run(y_pred, feed_dict={x: test_x}) predicted_y[np.where(predicted_y < 0)] = 0. writer.add_rows(id, (predicted_y * 10000).astype('int32')) writer.save_to_csv() print('result saved; {}'.format(model_name + '.csv'))
'실습 > tensorflow' 카테고리의 다른 글
Dataset 생성 및 사용 (0) | 2018.07.15 |
---|---|
버전 별 설치 (0) | 2017.08.15 |
The TensorFlow library wasn't compiled to use SSE instructions but these are available on your machine and could speed up CPU computations (0) | 2017.06.10 |