Update TF version

This commit is contained in:
Shuhei Iitsuka
2021-09-16 12:55:38 +09:00
parent 85c6365235
commit 1ecd083322
4 changed files with 15 additions and 17 deletions

View File

@@ -20,8 +20,8 @@ predict input characters from visualized trace.
"""
from collections import namedtuple
from tensorflow.contrib import slim
import tensorflow as tf
import tf_slim as slim
import tensorflow.compat.v1 as tf
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
@@ -161,7 +161,7 @@ def nazorunet(inputs,
is_training=True,
min_depth=8,
depth_multiplier=1.0,
prediction_fn=tf.contrib.layers.softmax,
prediction_fn=slim.layers.softmax,
spatial_squeeze=True,
reuse=None,
scope='NazoruNet',

View File

@@ -23,7 +23,7 @@ import numpy as np
def _load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
graph_def = tf.compat.v1.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
@@ -44,7 +44,7 @@ class NazoruPredictor():
inputs = lib.keydowns2image(data, True, True, 16, 2)
inputs = np.expand_dims(inputs, axis=0)
with utils.Measure('sess.run'):
with tf.Session(graph=self._graph) as sess:
with tf.compat.v1.Session(graph=self._graph) as sess:
result = sess.run(self._output_operation.outputs[0],
{self._input_operation.outputs[0]: inputs})[0]
return result