
Inspecting a Pretrained TensorFlow Model

linguana 2021. 6. 10. 12:42

When you receive a pretrained model, you first need to know the names of its input and output nodes.

Let's use TensorBoard to find them.


1. Install TensorBoard with pip

pip install tensorboard
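To confirm the install from Python (for example, inside the same virtual environment your model code will use), a small check like the following can help. This is a minimal sketch that assumes Python 3.8+ for importlib.metadata; installed_version is a helper name introduced here, not part of TensorBoard.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints a version string if pip installed TensorBoard, otherwise None.
print("tensorboard:", installed_version("tensorboard"))
```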

2. Launch TensorBoard from cmd

Once TensorBoard is installed, you can start it from cmd with the following command:

tensorboard --logdir=/tmp/tensorflow_logdir
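The same command can also be started from a Python script, which is handy when the rest of your workflow is already in Python. This is a sketch that assumes the `tensorboard` executable installed by pip is on PATH; tensorboard_command and launch_tensorboard are hypothetical helper names. `--port` is a regular TensorBoard flag, and 6006 is its default port.

```python
import subprocess


def tensorboard_command(logdir, port=6006):
    """Build the argument list for the tensorboard command shown above."""
    return ["tensorboard", f"--logdir={logdir}", f"--port={port}"]


def launch_tensorboard(logdir, port=6006):
    """Start TensorBoard in the background and return the process handle.

    Call .terminate() on the returned object to stop the server.
    """
    return subprocess.Popen(tensorboard_command(logdir, port))
```

For example, `proc = launch_tensorboard("/tmp/tensorflow_logdir")` starts the server, and `proc.terminate()` stops it when you are done.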

You should see output like the following:

[Screenshot: TensorBoard running in cmd]


3. Open a web browser and enter `localhost:6006` in the address bar.

[Screenshot: TensorBoard at localhost:6006 in a web browser]
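TensorBoard can take a few seconds before it starts serving, so opening the browser too early shows an error page. The sketch below waits until something accepts TCP connections on the port and then opens the page with Python's standard webbrowser module. wait_for_port and open_tensorboard are hypothetical helpers, and the check only confirms that the port is open, not that the process listening on it is TensorBoard.

```python
import socket
import time
import webbrowser


def wait_for_port(host, port, timeout=10.0):
    """Return True once host:port accepts a TCP connection, False after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connect means some server is listening on the port.
            with socket.create_connection((host, port), timeout=1.0):
                return True
        except OSError:
            time.sleep(0.2)  # not up yet; retry shortly
    return False


def open_tensorboard(host="localhost", port=6006):
    """Open TensorBoard in the default browser once the port is reachable."""
    if wait_for_port(host, port):
        webbrowser.open(f"http://{host}:{port}")
        return True
    return False
```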


4. Run import_pb_to_tensorboard.py

Save the following code, taken from [2], as import_pb_to_tensorboard.py:

# import_pb_to_tensorboard.py
"""Imports a protobuf model as a graph in Tensorboard."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from absl import app

from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.summary import summary
from tensorflow.python.tools import saved_model_utils

# Try importing TensorRT ops if available
# TODO(aaroey): ideally we should import everything from contrib, but currently
# tensorrt module would cause build errors when being imported in
# tensorflow/contrib/__init__.py. Fix it.
# pylint: disable=unused-import,g-import-not-at-top,wildcard-import
try:
  from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
except ImportError:
  pass
# pylint: enable=unused-import,g-import-not-at-top,wildcard-import


def import_to_tensorboard(model_dir, log_dir, tag_set):
  """View an SavedModel as a graph in Tensorboard.
  Args:
    model_dir: The directory containing the SavedModel to import.
    log_dir: The location for the Tensorboard log to begin visualization from.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
      separated by ','. For tag-set contains multiple tags, all tags must be
      passed in.
  Usage: Call this function with your SavedModel location and desired log
    directory. Launch Tensorboard by pointing it to the log directory. View your
    imported SavedModel as a graph.
  """
  with session.Session(graph=ops.Graph()) as sess:
    input_graph_def = saved_model_utils.get_meta_graph_def(model_dir,
                                                           tag_set).graph_def
    importer.import_graph_def(input_graph_def)

    pb_visual_writer = summary.FileWriter(log_dir)
    pb_visual_writer.add_graph(sess.graph)
    print("Model Imported. Visualize by running: "
          "tensorboard --logdir={}".format(log_dir))


def main(_):
  import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir, FLAGS.tag_set)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--model_dir",
      type=str,
      default="",
      required=True,
      help="The directory containing the SavedModel to import.")
  parser.add_argument(
      "--log_dir",
      type=str,
      default="",
      required=True,
      help="The location for the Tensorboard log to begin visualization from.")
  parser.add_argument(
      "--tag_set",
      type=str,
      default="serve",
      required=False,
      help='tag-set of graph in SavedModel to load, separated by \',\'')
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
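The script only draws the graph. The MetaGraphDef returned by saved_model_utils.get_meta_graph_def (the same call the script uses) also carries a signature_def map, and when the SavedModel was exported with signatures (the key is often serving_default) it names the input and output tensors directly. A minimal sketch; signature_io and load_meta_graph_def are hypothetical helpers, and only dense tensors are handled, since TensorInfo.name is empty for sparse or composite tensors.

```python
def signature_io(meta_graph_def):
    """Map each signature key to its input and output tensor names.

    Works on a MetaGraphDef protobuf, or on any object with the same
    signature_def -> inputs/outputs -> name shape.
    """
    result = {}
    for key, sig in meta_graph_def.signature_def.items():
        result[key] = {
            "inputs": {alias: info.name for alias, info in sig.inputs.items()},
            "outputs": {alias: info.name for alias, info in sig.outputs.items()},
        }
    return result


def load_meta_graph_def(model_dir, tag_set="serve"):
    """Load the MetaGraphDef the same way import_pb_to_tensorboard.py does."""
    # Imported lazily so signature_io() can be used without TensorFlow installed.
    from tensorflow.python.tools import saved_model_utils

    return saved_model_utils.get_meta_graph_def(model_dir, tag_set)
```

With TensorFlow installed, `print(signature_io(load_meta_graph_def("/path/to/saved_model_dir")))` prints the tensor names per signature; the path is a placeholder for your own SavedModel directory.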

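Open cmd and run the script as follows:
- --model_dir: the SavedModel directory, i.e. the folder that contains saved_model.pb. This version of the script loads the model with saved_model_utils.get_meta_graph_def, so passing the path of the .pb file itself fails.
- --log_dir: the directory where the TensorBoard logs will be written.
- --tag_set (optional): the tag set of the graph to load; it defaults to serve.

$ python import_pb_to_tensorboard.py --model_dir /tmp/mnist_model --log_dir /tmp/tensorflow_logdir

(/tmp/mnist_model is a placeholder; replace it with your own SavedModel directory.)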
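Because the script loads the model through saved_model_utils.get_meta_graph_def, --model_dir has to be a SavedModel directory; TensorFlow looks for a saved_model.pb (or saved_model.pbtxt) file inside it. A quick pre-check like the following, a hypothetical helper rather than part of the script, catches the common mistake of passing the .pb file itself.

```python
import os


def looks_like_saved_model(path):
    """Return True if `path` is a directory containing saved_model.pb or saved_model.pbtxt."""
    if not os.path.isdir(path):
        # A file path (e.g. .../saved_model.pb) is not a valid --model_dir.
        return False
    return any(
        os.path.isfile(os.path.join(path, name))
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )
```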

5. Launch TensorBoard

tensorboard --logdir=/tmp/tensorflow_logdir

6. Check the model in the browser

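You can now inspect the model at http://localhost:6006. The imported graph appears under the Graphs tab, where you can click on nodes to find the model's input and output nodes.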



Reference

[1] Dan Jarvis, "How to inspect a pre-trained TensorFlow model", Medium.

[2] tensorflow/tensorflow, "import_pb_to_tensorboard.py" (master branch), GitHub.

[3] 파이쿵, "How to use TensorBoard" (텐서보드 사용법), tistory.com.