mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-09 01:03:26 +03:00
Update TF version
This commit is contained in:
@@ -20,8 +20,8 @@ predict input characters from visualized trace.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
from tensorflow.contrib import slim
|
||||
import tensorflow as tf
|
||||
import tf_slim as slim
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
|
||||
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
|
||||
@@ -161,7 +161,7 @@ def nazorunet(inputs,
|
||||
is_training=True,
|
||||
min_depth=8,
|
||||
depth_multiplier=1.0,
|
||||
prediction_fn=tf.contrib.layers.softmax,
|
||||
prediction_fn=slim.layers.softmax,
|
||||
spatial_squeeze=True,
|
||||
reuse=None,
|
||||
scope='NazoruNet',
|
||||
|
||||
@@ -23,7 +23,7 @@ import numpy as np
|
||||
|
||||
def _load_graph(model_file):
|
||||
graph = tf.Graph()
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def = tf.compat.v1.GraphDef()
|
||||
with open(model_file, "rb") as f:
|
||||
graph_def.ParseFromString(f.read())
|
||||
with graph.as_default():
|
||||
@@ -44,7 +44,7 @@ class NazoruPredictor():
|
||||
inputs = lib.keydowns2image(data, True, True, 16, 2)
|
||||
inputs = np.expand_dims(inputs, axis=0)
|
||||
with utils.Measure('sess.run'):
|
||||
with tf.Session(graph=self._graph) as sess:
|
||||
with tf.compat.v1.Session(graph=self._graph) as sess:
|
||||
result = sess.run(self._output_operation.outputs[0],
|
||||
{self._input_operation.outputs[0]: inputs})[0]
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user