tf.train.Checkpoint
출처: www.tensorflow.org/api_docs/python/tf/train/Checkpoint
tf.train.Checkpoint | TensorFlow Core v2.4.1
Manages saving/restoring trackable values to disk.
www.tensorflow.org
하는 일: 추적할 수 있는 값(trackable value)을 디스크(disk)에서 관리(manage)한다 (영어 잘하쥬?)
기본 형식:
tf.train.Checkpoint(
root=None, **kwargs
)
추적 가능한 오브젝트 (trackable objects):
- tf.Variable,
- tf.keras.optimizers.Optimizer,
- tf.data.Dataset,
- tf.keras.Layer,
- tf.keras.Model 따위
첫번째 예제:
model = tf.keras.Model(어쩌고저쩌고)
checkpoint = tf.train.Checkpoint(model)
# 체크포인트 저장해라
save_path = checkpoint.save('패ㄷ쓰')
# 저장해둔 거 다시 복원해라
checkpoint.restore(save_path)
두번째 예제:
import tensorflow as tf
import os
checkpoint_directory = '패ㄸ쓰'
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# 어이, 그 체크포인트 만들어서 두 개 오브젝트 좀 관리해야겠다
# 하나는 거시기 "optimizer"라고 하고, 다른 하나는 "model"이라고 혀
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)
for _ in range(num_training_steps):
optimizer.minimize(어쩌고저쩌고) # 변수들이 복구된댜
status.assert_consumed() # Optional sanity checks.라고 하는데 '더블체크'한다는 의미인 듯
checkpoint.save(file_prefix=checkpoint_prefix)
### sanity 개념이 뭘까? ###
메소드 자기들 뭐가뭐가 있나? (read, restore, save, write)
1. read
메소드 read 와 write은 상호 반대 개념이고 짝짝꿍으로 쓰인다. 뭐, 쓰고 그리곤 읽고 이런 식인듯.
그니까 read는 save 말고 write랑 친구 먹었따.
# Create a checkpoint with write()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/my_checkpoint')
# Later, load the checkpoint with read()
# With restore() assert_consumed() would have failed.
checkpoint.read(path).assert_consumed()
# You can also pass options to read(). For example this
# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.read(path, options=options)
2. restore
save는 그럼 누구랑 친구냐? restore이랑 친구다! save는 write랑 조금 더 다른 게, save는 save_counter이란 변수도 기억한덴다. restore는 (1) 만약에 복원할 변수가 이미 만들어졌으면 변수에 할당을 곧바로 하던가 (2) 변수가 만들어지기까지 복원을 늦춰준다고 한다. (restore 좀 착한듯?) restore이 호출되고 난 후에 아까 말한 변수라던가 하는 혹부리 같은 놈들 (의존하는 것들) 은 걔네들이랑 대응되는 오브젝트가 체크포인트에 있으면 매칭이 된다고 한다. restore이란 애가 checkpoint에 있는 추적가능한 오브젝트들을 쫙 훓어보고 지 짝궁 기다리고 있으면 고것들 매칭되라고 request를 넣어준다고 한다. (restore 약간 주선자 삘이네? 중매인? 아무튼)
아 근데 이렇게 마음씨 좋은 restore은 마냥 매칭해주려고 기다리게 하면 안 될 거 아녀. 처리할 게 한 바가지 쌓여있는데 하루 죙일 저거 뭐 주선해준다고 기다려 봐. 일이 되나. 그래서, assert_consumed()를 써서 로딩이 완료가 되었다고 도장 꽉 찍어서 더 이상 할당할 게 (매칭해줄 게) 없다고 해주는 게 안전하다.
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path).assert_consumed()
# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options).assert_consumed()
아 근데 이거 너무 매정한 거 아니냐고 따진다면 뭐 걱정 마시라. 오브젝트(구혼남)가 체크포인트에서 할당될 값(예비 부인)을 못 찾거나 체크포인트 된 값(구혼녀)이 매칭된 오브젝트(예비 신랑)이 없으면 예외가 뜬다고 한다. 그러니까 아주 해피해피하게 다들 매칭될 거니까 뭣들 하나 걱정하지 않아도 된다.
그리고 좀 덧붙여서 말해주면... (엑기스는 다 말했고 좀 짜잘한 거)
(가) 텐서플로 1버전에 있던 tf.compat.v1.train.Saver은 name-based, 그니까 이름 기반인데 (이름으로 값 매칭한댜) 2 버전에서 이 tf.train.Checkpoint.save로 사용이 가능하다. 그니까 저 구닥다리 Saver를 구하고 싶으면(save) 체크포인트의 save 쓰시라 이 말씀이여.
(나) 케라스 SavedModel랑 호환되게 쓰려면 이렇게 써
model = tf.keras.Model(...)
tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()
expect_partial()를 쓴 이유는, 케라스에서 요구하는 값들이 무쟈게 많아서 "뭐가 안 쓰였네" "저게 안 쓰였네"하면서 경고가 아주 무섭게 나올 거라서, 그런 거 보기 싫으니까 "아 몰라, 일부만 복원될 거니까, 그렇게 알아!"라고 미리 선빵치는 거임.
3. save
아주아주 기본적인 메소드지. 요거 사용해서 체크포인트 저장해주면:
- 체크포인트가 추적해주는 오브젝트가 만든 변수
- 체크포인트가 추적해주는 오브젝트에 의존하는 어떤 추적가능한 오브젝트들이
체크포인트에 포함이 된다, 이 말이여. 물론 Checkpoint.save()가 호출된 시점이 기준이지. 쉽게 말하면, "쫄다구들 다 포함해서 추적해드립니다, 고객님 ^^" 이거다.
뭐, 아까도 말했지만 save는 누구랑 짝짝꿍이다? restore랑 절친이다. 긴 말 필요 없고 코드로 보자.
step = tf.Variable(0, name="step")
checkpoint = tf.Checkpoint(step=step)
checkpoint.save("/tmp/ckpt")
# Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt").assert_consumed()
# You can also pass options to save() and restore(). For example this
# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.save("/tmp/ckpt", options=options)
# Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt", options=options).assert_consumed()
제대로 된 작동이나 뭐 더 고급진 기능은 모르겠고, 암튼 그냥 외워. save는 restore.
save & restore
write & read
4. write
위에 save 설명한 거랑 비슷한데, 얘는 좀 뭔가 귀찮아 해. 많은 거 하고 싶지 않아.
뭔 말이냐면,
- checkpoint도 안 세고,
- save_counter도 안 세고(업데이트 안 해줌),
- tf.train.latest_checkpoint가 쓴 메타데이터도 업데이트 안 해준데.
쉽게 생각하면 save보다 좀 덜 떨어진 놈이지 (하위호환 정도라고 생각하자). 그리고 얘랑 누가 친구다? read다. 코드 보자:
step = tf.Variable(0, name="step")
checkpoint = tf.Checkpoint(step=step)
checkpoint.write("/tmp/ckpt")
# Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt").assert_consumed()
# You can also pass options to write() and read(). For example this
# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.write("/tmp/ckpt", options=options)
# Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
위에 나온 코드랑 완.전. 똑같지? ctrl+c, ctrl+v 정도 급인데, 그냥 (1) save → write, (2) restore → read로 바뀐 거 뿐임. 다시 한 번 강조한다, 외워.
save & restore
write & read
-THE END-