mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-08 16:53:28 +03:00
Add nazoru-input in mozc-nazoru/
77
mozc-nazoru/bin/nazoru-input
Executable file
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from nazoru.led import LED_BLUE, LED_RED

LED_RED.blink(1)

import argparse
import os
from nazoru import get_default_graph_path
from nazoru.core import create_keyboard_recorder, Bluetooth, NazoruPredictor


def main():
  FLAGS = None
  parser = argparse.ArgumentParser()
  parser.add_argument('-g', '--graph',
                      type=str,
                      default=get_default_graph_path(),
                      help='Path to a trained model which is generated by ' +
                           'nazoru-training.')
  parser.add_argument('-v', '--verbose', action='store_true')
  FLAGS, unparsed = parser.parse_known_args()

  LED_RED.blink(0.3)

  bt_connection = Bluetooth()
  try:
    recorder = create_keyboard_recorder(verbose=FLAGS.verbose)
  except IOError as e:
    LED_RED.off()
    LED_BLUE.off()
    raise e
  predictor = NazoruPredictor(FLAGS.graph)

  LED_RED.off()
  LED_BLUE.blink(1)

  print('Ready. Please input your scribble.')
  while True:
    data, command = recorder.record()
    if command is not None:
      print('command: %s' % command)
      bt_connection.command(command)
      continue
    if data is None:
      print('done.')
      break

    LED_RED.on()
    result = predictor.predict_top_n(data, 5)
    LED_RED.off()

    print('\n=== RESULTS ===')
    for item in result:
      print(u'%s (%s): %.5f' % (item[0], item[1], item[2]))
    print('===============\n')

    most_likely_result = result[0]
    print(u'%s (%s)' % (most_likely_result[0], most_likely_result[1]))
    bt_connection.send(most_likely_result[1])


if __name__ == '__main__':
  main()
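For anyone trying this outside the device, the predictor half of the loop can be driven on its own. Below is a minimal sketch, assuming the nazoru package is installed; it reuses only calls that appear above (get_default_graph_path, NazoruPredictor, predict_top_n), and the keydown tuples are invented sample data shaped like the (key, time) pairs that nazoru-training extracts from recorded events.

# Sketch: exercise the predictor without the keyboard recorder or Bluetooth.
# The keydown data below is made up; the real script gets it from
# recorder.record().
from nazoru import get_default_graph_path
from nazoru.core import NazoruPredictor

predictor = NazoruPredictor(get_default_graph_path())
keydowns = [('q', 0.00), ('w', 0.04), ('e', 0.09), ('r', 0.15), ('t', 0.22)]
for item in predictor.predict_top_n(keydowns, 5):
  # Each item unpacks the same way as in the results loop above.
  print(u'%s (%s): %.5f' % (item[0], item[1], item[2]))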
279
mozc-nazoru/bin/nazoru-training
Executable file
@@ -0,0 +1,279 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.contrib import slim
from tensorflow.contrib.slim.python.slim.learning import train_step
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import optimize_for_inference_lib

import argparse
import json
import nazoru.core as nazoru
import numpy as np
import os
import sys
import tensorflow as tf
import zipfile

FLAGS = None


def load_data(kanas):

  def get_ndjson_path(zip_path):
    dir_path, filename = os.path.split(zip_path)
    body, ext = os.path.splitext(filename)
    return dir_path, os.path.join(dir_path, '%s.ndjson' % (body))

  data = []
  labels = []
  dir_path, ndjson_path = get_ndjson_path(FLAGS.stroke_data)
  if not os.path.exists(ndjson_path):
    with zipfile.ZipFile(FLAGS.stroke_data, 'r') as f:
      f.extractall(dir_path)
  with open(ndjson_path) as f:
    for line in f.readlines():
      line = json.loads(line)
      if not line['kana'].lower() in kanas: continue
      keydowns = [(e[0], e[2]) for e in line['events'] if e[1] == 'down']
      # TODO (tushuhei) Relax this condition to accept shorter strokes.
      if len(keydowns) < 5: continue  # Ignore if too short.
      data.append(keydowns)
      labels.append(kanas.index(line['kana'].lower()))
  labels = np.eye(len(kanas))[labels]  # Idiom to one-hot encode.
  return data, labels


def convert_data_to_tensors(x, y):
  inputs = tf.constant(x, dtype=tf.float32)
  inputs.set_shape([None] + list(x[0].shape))
  outputs = tf.constant(y)
  outputs.set_shape([None] + list(y[0].shape))
  return inputs, outputs


def save_inference_graph(checkpoint_dir, input_shape, num_classes, conv_defs,
                         output_graph, optimized_output_graph,
                         input_node_name=nazoru.lib.INPUT_NODE_NAME,
                         output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  with tf.Graph().as_default():
    inputs = tf.placeholder(
        tf.float32, shape=[1] + list(input_shape), name=input_node_name)
    outputs, _ = nazoru.nazorunet(
        inputs, num_classes=num_classes, conv_defs=conv_defs, is_training=False,
        dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE)
    sv = tf.train.Supervisor(logdir=checkpoint_dir)
    with sv.managed_session() as sess:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(), [output_node_name])
    # TODO (tushuhei) Maybe we don't need to export unoptimized graph.
    with gfile.FastGFile(output_graph, 'wb') as f:
      f.write(output_graph_def.SerializeToString())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        output_graph_def, [input_node_name], [output_node_name],
        dtypes.float32.as_datatype_enum)

    with gfile.FastGFile(optimized_output_graph, 'wb') as f:
      f.write(output_graph_def.SerializeToString())


def export_saved_model_from_pb(saved_model_dir, graph_name,
                               input_node_name=nazoru.lib.INPUT_NODE_NAME,
                               output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  with tf.gfile.GFile(graph_name, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  sigs = {}

  with tf.Session(graph=tf.Graph()) as sess:
    tf.import_graph_def(graph_def, name='')
    g = tf.get_default_graph()

    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {input_node_name: g.get_tensor_by_name('%s:0' % (input_node_name))},
            {output_node_name: g.get_tensor_by_name('%s:0' % (output_node_name))})
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=sigs)

  builder.save()


def main(_):
  kanas = nazoru.lib.KANAS
  keydown_strokes, labels = load_data(kanas)
  x = np.array([
      nazoru.lib.keydowns2image(
          keydown_stroke,
          directional_feature=not FLAGS.no_directional_feature,
          temporal_feature=not FLAGS.no_temporal_feature,
          scale=FLAGS.image_size,
          stroke_width=FLAGS.stroke_width)
      for keydown_stroke in keydown_strokes])
  train_x, train_t, val_x, val_t, test_x, test_t = nazoru.lib.split_data(
      x, labels, 0.2, 0.2)
  conv_defs = [
      nazoru.Conv(kernel=[3, 3], stride=2, depth=32),
      nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
      nazoru.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
      nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
  ]

  with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    train_x, train_t = convert_data_to_tensors(train_x, train_t)
    val_x, val_t = convert_data_to_tensors(val_x, val_t)

    # Make the model.
    train_logits, train_endpoints = nazoru.nazorunet(
        train_x, num_classes=len(kanas), conv_defs=conv_defs,
        dropout_keep_prob=FLAGS.dropout_keep_prob, scope=nazoru.lib.SCOPE,
        reuse=tf.AUTO_REUSE, is_training=True)
    val_logits, val_endpoints = nazoru.nazorunet(
        val_x, num_classes=len(kanas), conv_defs=conv_defs,
        dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE,
        is_training=True)

    # Add the loss function to the graph.
    tf.losses.softmax_cross_entropy(train_t, train_logits)
    train_total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('train_total_loss', train_total_loss)

    val_total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('val_total_loss', val_total_loss)

    train_accuracy = tf.reduce_mean(tf.cast(tf.equal(
        tf.argmax(train_logits, axis=1), tf.argmax(train_t, axis=1)),
        tf.float32))
    tf.summary.scalar('train_accuracy', train_accuracy)

    val_accuracy = tf.reduce_mean(tf.cast(tf.equal(
        tf.argmax(val_logits, axis=1), tf.argmax(val_t, axis=1)), tf.float32))
    tf.summary.scalar('val_accuracy', val_accuracy)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.005)
    train_op = slim.learning.create_train_op(train_total_loss, optimizer)

    def train_step_fn(sess, *args, **kwargs):
      total_loss, should_stop = train_step(sess, *args, **kwargs)
      if train_step_fn.step % FLAGS.n_steps_to_log == 0:
        val_acc = sess.run(val_accuracy)
        tf_logging.info('step: %d, validation accuracy: %.3f' % (
            train_step_fn.step, val_acc))
      train_step_fn.step += 1
      return [total_loss, should_stop]
    train_step_fn.step = 0

    # Run the training inside a session.
    final_loss = slim.learning.train(
        train_op,
        logdir=FLAGS.checkpoint_dir,
        number_of_steps=FLAGS.n_train_steps,
        train_step_fn=train_step_fn,
        save_summaries_secs=5,
        log_every_n_steps=FLAGS.n_steps_to_log)

  save_inference_graph(checkpoint_dir=FLAGS.checkpoint_dir,
                       input_shape=train_x.shape[1:], num_classes=len(kanas),
                       conv_defs=conv_defs, output_graph=FLAGS.output_graph,
                       optimized_output_graph=FLAGS.optimized_output_graph)
  export_saved_model_from_pb(saved_model_dir=FLAGS.saved_model_dir,
                             graph_name=FLAGS.optimized_output_graph)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_size',
      type=int,
      default=16,
      help='Size of the square canvas to render input strokes.'
  )
  parser.add_argument(
      '--stroke_width',
      type=int,
      default=2,
      help='Stroke width of rendered strokes.'
  )
  parser.add_argument(
      '--n_train_steps',
      type=int,
      default=500,
      help='Number of training steps to run.'
  )
  parser.add_argument(
      '--n_steps_to_log',
      type=int,
      default=10,
      help='How often to print logs during training.'
  )
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default=os.path.join(os.sep, 'tmp', 'nazorunet'),
      help='Where to save checkpoint files.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='nazoru.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--optimized_output_graph',
      type=str,
      default='optimized_nazoru.pb',
      help='Where to save the trained graph optimized for inference.'
  )
  parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='nazoru_saved_model',
      help='Where to save the exported graph.'
  )
  parser.add_argument(
      '--dropout_keep_prob',
      type=float,
      default=0.8,
      help='The fraction of activation values that is retained by dropout.'
  )
  parser.add_argument(
      'stroke_data',
      type=str,
      help='Path to zipped stroke data to input. You can download the ' +
           'default stroke data at ' +
           'https://github.com/google/mozc-devices/mozc-nazoru/data/strokes.zip.'
  )
  parser.add_argument(
      '--no_directional_feature',
      action='store_true',
      help='Do not use the directional feature.'
  )
  parser.add_argument(
      '--no_temporal_feature',
      action='store_true',
      help='Do not use the temporal feature.'
  )

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
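As a sanity check on the exported artifacts, the optimized frozen graph can be loaded back for a single inference. A minimal TF 1.x sketch follows; 'input' and 'output' are stand-ins for whatever nazoru.lib.INPUT_NODE_NAME and nazoru.lib.OUTPUT_NODE_NAME actually resolve to, and the all-zeros image is dummy data.

# Sketch: one forward pass through optimized_nazoru.pb (TF 1.x APIs).
# 'input'/'output' are placeholders for nazoru.lib.INPUT_NODE_NAME and
# nazoru.lib.OUTPUT_NODE_NAME; substitute the real constants before running.
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('optimized_nazoru.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Session(graph=tf.Graph()) as sess:
  tf.import_graph_def(graph_def, name='')
  inputs = sess.graph.get_tensor_by_name('input:0')
  outputs = sess.graph.get_tensor_by_name('output:0')
  # save_inference_graph froze the model with a [1] + input_shape placeholder,
  # so a single all-zeros image of that shape is a valid feed.
  image = np.zeros(inputs.shape.as_list(), dtype=np.float32)
  print(sess.run(outputs, feed_dict={inputs: image}))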