mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-08 16:53:28 +03:00
280 lines
9.8 KiB
Python
Executable File
280 lines
9.8 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from tensorflow.contrib import slim
|
|
from tensorflow.contrib.slim.python.slim.learning import train_step
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import graph_util
|
|
from tensorflow.python.platform import gfile
|
|
from tensorflow.python.platform import tf_logging
|
|
from tensorflow.python.saved_model import signature_constants
|
|
from tensorflow.python.saved_model import tag_constants
|
|
from tensorflow.python.tools import optimize_for_inference_lib
|
|
|
|
import argparse
|
|
import json
|
|
import nazoru.core as nazoru
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
import tensorflow as tf
|
|
import zipfile
|
|
|
|
FLAGS = None
|
|
|
|
def load_data(kanas, stroke_data=None):
  """Loads keydown stroke data and one-hot labels for the given kana set.

  The zipped stroke data is expected to contain a single .ndjson file named
  after the archive itself (e.g. strokes.zip -> strokes.ndjson). The archive
  is extracted only once; later runs reuse the extracted ndjson file.

  Args:
    kanas: Sequence of lower-cased kana labels to keep. Its order defines the
      one-hot label encoding.
    stroke_data: Optional path to the zipped stroke data. Defaults to
      FLAGS.stroke_data for backward compatibility with existing callers.

  Returns:
    A (data, labels) tuple where data is a list of keydown strokes (each a
    list of (timestamp, key) tuples) and labels is a one-hot numpy array of
    shape [len(data), len(kanas)] aligned with `kanas`.
  """
  if stroke_data is None:
    stroke_data = FLAGS.stroke_data

  def get_ndjson_path(zip_path):
    # Derive <dir>/<body>.ndjson from <dir>/<body>.zip.
    dir_path, filename = os.path.split(zip_path)
    body, _ = os.path.splitext(filename)
    return dir_path, os.path.join(dir_path, '%s.ndjson' % (body))

  data = []
  labels = []
  dir_path, ndjson_path = get_ndjson_path(stroke_data)
  if not os.path.exists(ndjson_path):
    with zipfile.ZipFile(stroke_data, 'r') as f:
      f.extractall(dir_path)
  with open(ndjson_path) as f:
    # Iterate the file object directly instead of readlines() so the whole
    # file is not materialized in memory at once.
    for line in f:
      line = json.loads(line)
      if line['kana'].lower() not in kanas: continue
      # Each event is (timestamp, event_type, key); keep key-down events only.
      keydowns = [(e[0], e[2]) for e in line['events'] if e[1] == 'down']
      # TODO (tushuhei) Relax this condition to accept shorter strokes.
      if len(keydowns) < 5: continue  # Ignore if too short.
      data.append(keydowns)
      labels.append(kanas.index(line['kana'].lower()))
  labels = np.eye(len(kanas))[labels]  # Idiom to one-hot encode.
  return data, labels
|
|
|
|
|
|
def convert_data_to_tensors(x, y):
  """Wraps numpy inputs/targets as TF constant tensors with a free batch dim.

  Args:
    x: Sequence of numpy arrays; every element shares x[0]'s shape. Cast to
      float32.
    y: Sequence of numpy arrays (one-hot targets); every element shares
      y[0]'s shape. Keeps its natural dtype.

  Returns:
    An (inputs, outputs) pair of constant tensors whose static shape is
    [None] + element_shape, i.e. the batch dimension is left unspecified.
  """
  def as_batch_tensor(values, dtype=None):
    # Build the constant, then relax the leading (batch) dimension to None.
    tensor = tf.constant(values, dtype=dtype)
    tensor.set_shape([None] + list(values[0].shape))
    return tensor

  return as_batch_tensor(x, dtype=tf.float32), as_batch_tensor(y)
|
|
|
|
|
|
def save_inference_graph(checkpoint_dir, input_shape, num_classes, conv_defs,
                         output_graph, optimized_output_graph,
                         input_node_name=nazoru.lib.INPUT_NODE_NAME,
                         output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  """Freezes the trained model into serialized inference GraphDefs on disk.

  Builds an inference-mode copy of the network (batch size 1, dropout
  disabled via dropout_keep_prob=1), restores the latest checkpoint from
  `checkpoint_dir`, converts variables to constants, and writes two files:
  the plain frozen graph (`output_graph`) and a copy run through the
  optimize-for-inference pass (`optimized_output_graph`).
  """
  with tf.Graph().as_default():
    # Fixed batch size of 1: the frozen graph targets single-stroke inference.
    inputs = tf.placeholder(
        tf.float32, shape=[1] + list(input_shape), name=input_node_name)
    outputs, _ = nazoru.nazorunet(
        inputs, num_classes=num_classes, conv_defs=conv_defs, is_training=False,
        dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE)
    # Supervisor restores the latest checkpoint from checkpoint_dir into the
    # managed session before the graph is frozen below.
    sv = tf.train.Supervisor(logdir=checkpoint_dir)
    with sv.managed_session() as sess:
      # Replace variables with constants so the graph is self-contained.
      output_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(), [output_node_name])
      # TODO (tushuhei) Maybe we don't need to export unoptimized graph.
      with gfile.FastGFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())

      # Strip training-only nodes and fold what can be folded for inference.
      output_graph_def = optimize_for_inference_lib.optimize_for_inference(
          output_graph_def, [input_node_name], [output_node_name],
          dtypes.float32.as_datatype_enum)

      with gfile.FastGFile(optimized_output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
|
|
|
|
def export_saved_model_from_pb(saved_model_dir, graph_name,
                               input_node_name=nazoru.lib.INPUT_NODE_NAME,
                               output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  """Wraps a frozen GraphDef (.pb file) as a TensorFlow SavedModel.

  Loads the serialized graph from `graph_name`, imports it into a fresh
  session, and saves it under `saved_model_dir` with a single default
  serving signature mapping the named input tensor to the named output
  tensor.
  """
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  with tf.gfile.GFile(graph_name, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  sigs = {}

  with tf.Session(graph=tf.Graph()) as sess:
    # name='' keeps the original node names so the ':0' tensor lookups below
    # resolve without an import prefix.
    tf.import_graph_def(graph_def, name='')
    g = tf.get_default_graph()

    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {input_node_name: g.get_tensor_by_name('%s:0' % (input_node_name))},
            {output_node_name: g.get_tensor_by_name('%s:0' % (output_node_name))
            })
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=sigs)

  builder.save()
|
|
|
|
|
|
def main(_):
  """Trains nazorunet on rendered stroke data, then exports the model.

  Pipeline: load strokes -> render them to images -> split train/val/test ->
  build train and validation graphs -> train with slim -> freeze inference
  graphs and export a SavedModel. All knobs come from the module-level FLAGS.
  """
  kanas = nazoru.lib.KANAS
  keydown_strokes, labels = load_data(kanas)
  # Render each keydown stroke sequence into a square image; the optional
  # directional/temporal channels are controlled by command-line flags.
  x = np.array([
      nazoru.lib.keydowns2image(
          keydown_stroke,
          directional_feature=not FLAGS.no_directional_feature,
          temporal_feature=not FLAGS.no_temporal_feature,
          scale=FLAGS.image_size,
          stroke_width=FLAGS.stroke_width)
      for keydown_stroke in keydown_strokes])
  # Hold out 20% for validation and 20% for test.
  train_x, train_t, val_x, val_t, test_x, test_t = nazoru.lib.split_data(
      x, labels, 0.2, 0.2)
  # MobileNet-style stack: one full convolution followed by depthwise
  # separable convolutions.
  conv_defs = [
      nazoru.Conv(kernel=[3, 3], stride=2, depth=32),
      nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
      nazoru.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
      nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
  ]

  with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    train_x, train_t = convert_data_to_tensors(train_x, train_t)
    val_x, val_t = convert_data_to_tensors(val_x, val_t)

    # Make the model. Both towers share weights via reuse=tf.AUTO_REUSE and
    # the common scope.
    train_logits, train_endpoints = nazoru.nazorunet(
        train_x, num_classes=len(kanas), conv_defs=conv_defs,
        dropout_keep_prob=FLAGS.dropout_keep_prob, scope=nazoru.lib.SCOPE,
        reuse=tf.AUTO_REUSE, is_training=True)
    # NOTE(review): the validation tower also passes is_training=True.
    # Dropout is neutralized by dropout_keep_prob=1, but any other
    # training-mode behavior in the net would still apply — confirm intended.
    val_logits, val_endpoints = nazoru.nazorunet(
        val_x, num_classes=len(kanas), conv_defs=conv_defs,
        dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE,
        is_training=True)

    # Add the loss function to the graph.
    tf.losses.softmax_cross_entropy(train_t, train_logits)
    train_total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('train_total_loss', train_total_loss)

    # NOTE(review): get_total_loss() sums the global losses collection, and
    # no validation loss was added to it, so this summary reports the same
    # value as train_total_loss — verify this is what was intended.
    val_total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('val_total_loss', val_total_loss)

    # Accuracy = fraction of examples whose argmax prediction matches the
    # one-hot target's argmax.
    train_accuracy = tf.reduce_mean(tf.cast(tf.equal(
        tf.argmax(train_logits, axis=1), tf.argmax(train_t, axis=1)), tf.float32))
    tf.summary.scalar('train_accuracy', train_accuracy)

    val_accuracy = tf.reduce_mean(tf.cast(tf.equal(
        tf.argmax(val_logits, axis=1), tf.argmax(val_t, axis=1)), tf.float32))
    tf.summary.scalar('val_accuracy', val_accuracy)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.005)
    train_op = slim.learning.create_train_op(train_total_loss, optimizer)

    def train_step_fn(sess, *args, **kwargs):
      # Wraps slim's train_step to log validation accuracy every
      # FLAGS.n_steps_to_log steps; the step counter lives on the function
      # object itself.
      total_loss, should_stop = train_step(sess, *args, **kwargs)
      if train_step_fn.step % FLAGS.n_steps_to_log == 0:
        val_acc = sess.run(val_accuracy)
        tf_logging.info('step: %d, validation accuracy: %.3f' % (
            train_step_fn.step, val_acc))
      train_step_fn.step += 1
      return [total_loss, should_stop]
    train_step_fn.step = 0

    # Run the training inside a session.
    final_loss = slim.learning.train(
        train_op,
        logdir=FLAGS.checkpoint_dir,
        number_of_steps=FLAGS.n_train_steps,
        train_step_fn=train_step_fn,
        save_summaries_secs=5,
        log_every_n_steps=FLAGS.n_steps_to_log)

  # Freeze the trained checkpoint into .pb graphs, then wrap the optimized
  # graph as a SavedModel (both helpers open their own graphs/sessions).
  save_inference_graph(checkpoint_dir=FLAGS.checkpoint_dir,
                       input_shape=train_x.shape[1:], num_classes=len(kanas),
                       conv_defs=conv_defs, output_graph=FLAGS.output_graph,
                       optimized_output_graph=FLAGS.optimized_output_graph)
  export_saved_model_from_pb(saved_model_dir=FLAGS.saved_model_dir,
                             graph_name=FLAGS.optimized_output_graph)
|
|
|
|
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_size',
      type=int,
      default=16,
      help='Size of the square canvas to render input strokes.'
  )
  parser.add_argument(
      '--stroke_width',
      type=int,
      default=2,
      help='Stroke width of rendered strokes.'
  )
  parser.add_argument(
      '--n_train_steps',
      type=int,
      default=500,
      help='Number of training steps to run.'
  )
  parser.add_argument(
      '--n_steps_to_log',
      type=int,
      default=10,
      help='How often to print log during training.'
  )
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default=os.path.join(os.sep, 'tmp', 'nazorunet'),
      help='Where to save checkpoint files.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='nazoru.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--optimized_output_graph',
      type=str,
      default='optimized_nazoru.pb',
      help='Where to save the trained graph optimized for inference.'
  )
  parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='nazoru_saved_model',
      help='Where to save the exported graph.'
  )
  parser.add_argument(
      '--dropout_keep_prob',
      type=float,
      default=0.8,
      help='The percentage of activation values that are retained in dropout.'
  )
  parser.add_argument(
      'stroke_data',
      type=str,
      help='Path to zipped stroke data to input. You can download the ' +
           'default stroke data at ' +
           'https://github.com/google/mozc-devices/mozc-nazoru/data/strokes.zip.'
  )
  # Bug fix: these flags previously used action='store_false', which made
  # FLAGS.no_*_feature default to True — i.e. both features were DISABLED by
  # default and passing the flag re-enabled them, the opposite of the flag
  # names and help text. store_true defaults to False (features on), and
  # passing the flag disables the feature, as intended.
  parser.add_argument(
      '--no_directional_feature',
      action='store_true',
      help='Not to use directional feature.'
  )
  parser.add_argument(
      '--no_temporal_feature',
      action='store_true',
      help='Not to use temporal feature.'
  )

  # Unrecognized arguments are forwarded to tf.app.run unchanged.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|