mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-08 16:53:28 +03:00
@@ -15,8 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from tensorflow.contrib import slim
|
import tf_slim as slim
|
||||||
from tensorflow.contrib.slim.python.slim.learning import train_step
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
@@ -31,7 +30,7 @@ import nazoru.core as nazoru
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tensorflow as tf
|
import tensorflow.compat.v1 as tf
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
FLAGS = None
|
FLAGS = None
|
||||||
@@ -139,6 +138,7 @@ def main(_):
|
|||||||
nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
|
nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
tf.disable_eager_execution()
|
||||||
with tf.Graph().as_default():
|
with tf.Graph().as_default():
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ def main(_):
|
|||||||
train_op = slim.learning.create_train_op(train_total_loss, optimizer)
|
train_op = slim.learning.create_train_op(train_total_loss, optimizer)
|
||||||
|
|
||||||
def train_step_fn(sess, *args, **kwargs):
|
def train_step_fn(sess, *args, **kwargs):
|
||||||
total_loss, should_stop = train_step(sess, *args, **kwargs)
|
total_loss, should_stop = slim.learning.train_step(sess, *args, **kwargs)
|
||||||
if train_step_fn.step % FLAGS.n_steps_to_log == 0:
|
if train_step_fn.step % FLAGS.n_steps_to_log == 0:
|
||||||
val_acc = sess.run(val_accuracy)
|
val_acc = sess.run(val_accuracy)
|
||||||
tf_logging.info('step: %d, validation accuracy: %.3f' % (
|
tf_logging.info('step: %d, validation accuracy: %.3f' % (
|
||||||
@@ -236,19 +236,19 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output_graph',
|
'--output_graph',
|
||||||
type=str,
|
type=str,
|
||||||
default='nazoru.pb',
|
default='nazoru_custom.pb',
|
||||||
help='Where to save the trained graph.'
|
help='Where to save the trained graph.'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--optimized_output_graph',
|
'--optimized_output_graph',
|
||||||
type=str,
|
type=str,
|
||||||
default='optimized_nazoru.pb',
|
default='optimized_nazoru_custom.pb',
|
||||||
help='Where to save the trained graph optimized for inference.'
|
help='Where to save the trained graph optimized for inference.'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--saved_model_dir',
|
'--saved_model_dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='nazoru_saved_model',
|
default='nazoru_saved_model_custom',
|
||||||
help='Where to save the exported graph.'
|
help='Where to save the exported graph.'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -59,12 +59,10 @@ setup(
|
|||||||
# data_files=[('/etc/systemd/system', ['data/nazoru.service'])],
|
# data_files=[('/etc/systemd/system', ['data/nazoru.service'])],
|
||||||
|
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'np_utils',
|
'cairocffi',
|
||||||
'cairocffi<=1.0.0',
|
|
||||||
'h5py',
|
|
||||||
'pillow',
|
'pillow',
|
||||||
'tensorflow~=1.15.4',
|
'tensorflow~=2.5.1',
|
||||||
'markdown<=3.0.1',
|
'tf_slim~=1.1.0',
|
||||||
'enum34;python_version<"3.4"',
|
'enum34;python_version<"3.4"',
|
||||||
'pyserial',
|
'pyserial',
|
||||||
'evdev;platform_system=="Linux"',
|
'evdev;platform_system=="Linux"',
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ predict input characters from visualized trace.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from tensorflow.contrib import slim
|
import tf_slim as slim
|
||||||
import tensorflow as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
|
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
|
||||||
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
|
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
|
||||||
@@ -161,7 +161,7 @@ def nazorunet(inputs,
|
|||||||
is_training=True,
|
is_training=True,
|
||||||
min_depth=8,
|
min_depth=8,
|
||||||
depth_multiplier=1.0,
|
depth_multiplier=1.0,
|
||||||
prediction_fn=tf.contrib.layers.softmax,
|
prediction_fn=slim.layers.softmax,
|
||||||
spatial_squeeze=True,
|
spatial_squeeze=True,
|
||||||
reuse=None,
|
reuse=None,
|
||||||
scope='NazoruNet',
|
scope='NazoruNet',
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import numpy as np
|
|||||||
|
|
||||||
def _load_graph(model_file):
|
def _load_graph(model_file):
|
||||||
graph = tf.Graph()
|
graph = tf.Graph()
|
||||||
graph_def = tf.GraphDef()
|
graph_def = tf.compat.v1.GraphDef()
|
||||||
with open(model_file, "rb") as f:
|
with open(model_file, "rb") as f:
|
||||||
graph_def.ParseFromString(f.read())
|
graph_def.ParseFromString(f.read())
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
@@ -44,7 +44,7 @@ class NazoruPredictor():
|
|||||||
inputs = lib.keydowns2image(data, True, True, 16, 2)
|
inputs = lib.keydowns2image(data, True, True, 16, 2)
|
||||||
inputs = np.expand_dims(inputs, axis=0)
|
inputs = np.expand_dims(inputs, axis=0)
|
||||||
with utils.Measure('sess.run'):
|
with utils.Measure('sess.run'):
|
||||||
with tf.Session(graph=self._graph) as sess:
|
with tf.compat.v1.Session(graph=self._graph) as sess:
|
||||||
result = sess.run(self._output_operation.outputs[0],
|
result = sess.run(self._output_operation.outputs[0],
|
||||||
{self._input_operation.outputs[0]: inputs})[0]
|
{self._input_operation.outputs[0]: inputs})[0]
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user