#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trains NazoruNet on keydown stroke data and exports inference graphs."""

from tensorflow.contrib import slim
from tensorflow.contrib.slim.python.slim.learning import train_step
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import optimize_for_inference_lib

import argparse
import json
import nazoru.core as nazoru
import numpy as np
import os
import sys
import tensorflow as tf
import zipfile

FLAGS = None


def load_data(kanas):
  """Loads keydown strokes and one-hot labels for the given kanas.

  FLAGS.stroke_data points to a zip archive containing an ndjson file of the
  same base name; the archive is extracted on first use.
  """

  def get_ndjson_path(zip_path):
    dir_path, filename = os.path.split(zip_path)
    body, ext = os.path.splitext(filename)
    return dir_path, os.path.join(dir_path, '%s.ndjson' % (body))

  data = []
  labels = []
  dir_path, ndjson_path = get_ndjson_path(FLAGS.stroke_data)
  if not os.path.exists(ndjson_path):
    with zipfile.ZipFile(FLAGS.stroke_data, 'r') as f:
      f.extractall(dir_path)
  with open(ndjson_path) as f:
    for line in f.readlines():
      line = json.loads(line)
      if line['kana'].lower() not in kanas: continue
      keydowns = [(e[0], e[2]) for e in line['events'] if e[1] == 'down']
      # TODO (tushuhei) Relax this condition to accept shorter strokes.
      if len(keydowns) < 5: continue  # Ignore if too short.
      data.append(keydowns)
      labels.append(kanas.index(line['kana'].lower()))
  labels = np.eye(len(kanas))[labels]  # Idiom to one-hot encode.
  return data, labels


def convert_data_to_tensors(x, y):
  """Wraps numpy inputs and labels into constant tensors."""
  inputs = tf.constant(x, dtype=tf.float32)
  inputs.set_shape([None] + list(x[0].shape))
  outputs = tf.constant(y)
  outputs.set_shape([None] + list(y[0].shape))
  return inputs, outputs


def save_inference_graph(checkpoint_dir, input_shape, num_classes, conv_defs,
                         output_graph, optimized_output_graph,
                         input_node_name=nazoru.lib.INPUT_NODE_NAME,
                         output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  """Rebuilds the network for single-example inference, restores the latest
  checkpoint, and writes frozen plain and inference-optimized GraphDefs."""
  with tf.Graph().as_default():
    inputs = tf.placeholder(
        tf.float32, shape=[1] + list(input_shape), name=input_node_name)
    outputs, _ = nazoru.nazorunet(
        inputs,
        num_classes=num_classes,
        conv_defs=conv_defs,
        is_training=False,
        dropout_keep_prob=1,
        scope=nazoru.lib.SCOPE,
        reuse=tf.AUTO_REUSE)
    sv = tf.train.Supervisor(logdir=checkpoint_dir)
    with sv.managed_session() as sess:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(), [output_node_name])
      # TODO (tushuhei) Maybe we don't need to export unoptimized graph.
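      # convert_variables_to_constants has folded the checkpointed weights
      # into the GraphDef as constants. The first write below dumps that
      # frozen graph as-is; the second strips nodes that
      # optimize_for_inference deems unnecessary at serving time.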
      with gfile.FastGFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
      output_graph_def = optimize_for_inference_lib.optimize_for_inference(
          output_graph_def, [input_node_name], [output_node_name],
          dtypes.float32.as_datatype_enum)
      with gfile.FastGFile(optimized_output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())


def export_saved_model_from_pb(saved_model_dir, graph_name,
                               input_node_name=nazoru.lib.INPUT_NODE_NAME,
                               output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  """Wraps a frozen GraphDef into a SavedModel with a serving signature."""
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  with tf.gfile.GFile(graph_name, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  sigs = {}
  with tf.Session(graph=tf.Graph()) as sess:
    tf.import_graph_def(graph_def, name='')
    g = tf.get_default_graph()
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {input_node_name: g.get_tensor_by_name('%s:0' % (input_node_name))},
            {output_node_name: g.get_tensor_by_name('%s:0' % (output_node_name))})
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=sigs)
  builder.save()


def main(_):
  kanas = nazoru.lib.KANAS
  keydown_strokes, labels = load_data(kanas)
  x = np.array([
      nazoru.lib.keydowns2image(
          keydown_stroke,
          directional_feature=not FLAGS.no_directional_feature,
          temporal_feature=not FLAGS.no_temporal_feature,
          scale=FLAGS.image_size,
          stroke_width=FLAGS.stroke_width)
      for keydown_stroke in keydown_strokes])
  train_x, train_t, val_x, val_t, test_x, test_t = nazoru.lib.split_data(
      x, labels, 0.2, 0.2)

  conv_defs = [
      nazoru.Conv(kernel=[3, 3], stride=2, depth=32),
      nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
      nazoru.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
      nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
  ]

  with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)
    train_x, train_t = convert_data_to_tensors(train_x, train_t)
    val_x, val_t = convert_data_to_tensors(val_x, val_t)

    # Make the model.
    train_logits, train_endpoints = nazoru.nazorunet(
        train_x,
        num_classes=len(kanas),
        conv_defs=conv_defs,
        dropout_keep_prob=FLAGS.dropout_keep_prob,
        scope=nazoru.lib.SCOPE,
        reuse=tf.AUTO_REUSE,
        is_training=True)

    # The validation network shares its weights with the training network
    # (reuse=tf.AUTO_REUSE) and runs in inference mode, so dropout is
    # disabled and no batch-norm statistics are updated from validation
    # data.
    val_logits, val_endpoints = nazoru.nazorunet(
        val_x,
        num_classes=len(kanas),
        conv_defs=conv_defs,
        dropout_keep_prob=1,
        scope=nazoru.lib.SCOPE,
        reuse=tf.AUTO_REUSE,
        is_training=False)

    # Add the loss function to the graph.
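    # tf.losses.softmax_cross_entropy registers the cross-entropy between
    # the one-hot labels and the logits in the graph's LOSSES collection;
    # tf.losses.get_total_loss() then returns the sum of that collection
    # plus any regularization losses.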
    tf.losses.softmax_cross_entropy(train_t, train_logits)
    train_total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('train_total_loss', train_total_loss)
    # Keep the validation loss out of the LOSSES collection
    # (loss_collection=None) so it is not folded into the training loss.
    val_total_loss = tf.losses.softmax_cross_entropy(
        val_t, val_logits, loss_collection=None)
    tf.summary.scalar('val_total_loss', val_total_loss)

    train_accuracy = tf.reduce_mean(tf.cast(tf.equal(
        tf.argmax(train_logits, axis=1), tf.argmax(train_t, axis=1)),
        tf.float32))
    tf.summary.scalar('train_accuracy', train_accuracy)
    val_accuracy = tf.reduce_mean(tf.cast(tf.equal(
        tf.argmax(val_logits, axis=1), tf.argmax(val_t, axis=1)),
        tf.float32))
    tf.summary.scalar('val_accuracy', val_accuracy)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.005)
    train_op = slim.learning.create_train_op(train_total_loss, optimizer)

    # Wrap slim's default train_step to report validation accuracy every
    # FLAGS.n_steps_to_log steps.
    def train_step_fn(sess, *args, **kwargs):
      total_loss, should_stop = train_step(sess, *args, **kwargs)
      if train_step_fn.step % FLAGS.n_steps_to_log == 0:
        val_acc = sess.run(val_accuracy)
        tf_logging.info('step: %d, validation accuracy: %.3f' % (
            train_step_fn.step, val_acc))
      train_step_fn.step += 1
      return [total_loss, should_stop]
    train_step_fn.step = 0

    # Run the training inside a session.
    final_loss = slim.learning.train(
        train_op,
        logdir=FLAGS.checkpoint_dir,
        number_of_steps=FLAGS.n_train_steps,
        train_step_fn=train_step_fn,
        save_summaries_secs=5,
        log_every_n_steps=FLAGS.n_steps_to_log)

  save_inference_graph(
      checkpoint_dir=FLAGS.checkpoint_dir,
      input_shape=train_x.shape[1:],
      num_classes=len(kanas),
      conv_defs=conv_defs,
      output_graph=FLAGS.output_graph,
      optimized_output_graph=FLAGS.optimized_output_graph)
  export_saved_model_from_pb(
      saved_model_dir=FLAGS.saved_model_dir,
      graph_name=FLAGS.optimized_output_graph)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_size',
      type=int,
      default=16,
      help='Size of the square canvas on which input strokes are rendered.'
  )
  parser.add_argument(
      '--stroke_width',
      type=int,
      default=2,
      help='Stroke width of rendered strokes.'
  )
  parser.add_argument(
      '--n_train_steps',
      type=int,
      default=500,
      help='Number of training steps to run.'
  )
  parser.add_argument(
      '--n_steps_to_log',
      type=int,
      default=10,
      help='How often (in steps) to log progress during training.'
  )
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default=os.path.join(os.sep, 'tmp', 'nazorunet'),
      help='Where to save checkpoint files.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='nazoru.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--optimized_output_graph',
      type=str,
      default='optimized_nazoru.pb',
      help='Where to save the trained graph optimized for inference.'
  )
  parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='nazoru_saved_model',
      help='Where to save the exported SavedModel.'
  )
  parser.add_argument(
      '--dropout_keep_prob',
      type=float,
      default=0.8,
      help='The fraction of activation values retained by dropout.'
  )
  parser.add_argument(
      'stroke_data',
      type=str,
      help='Path to the zipped stroke data to input. You can download the ' +
           'default stroke data at ' +
           'https://github.com/google/mozc-devices/mozc-nazoru/data/strokes.zip.'
  )
  parser.add_argument(
      '--no_directional_feature',
      action='store_true',
      help='Do not use the directional feature.'
  )
  parser.add_argument(
      '--no_temporal_feature',
      action='store_true',
      help='Do not use the temporal feature.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
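
# A typical invocation, assuming the stroke data has been downloaded to
# data/strokes.zip (the script and data paths here are illustrative):
#
#   python train.py data/strokes.zip
#
# This trains for --n_train_steps steps, writes checkpoints and TensorBoard
# summaries under --checkpoint_dir, and exports nazoru.pb,
# optimized_nazoru.pb, and the nazoru_saved_model directory.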