mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-08 16:53:28 +03:00
Add nazoru-input in mozc-nazoru/
This commit is contained in:
114
mozc-nazoru/.gitignore
vendored
Normal file
114
mozc-nazoru/.gitignore
vendored
Normal file
@@ -0,0 +1,114 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
.static_storage/
|
||||
.media/
|
||||
local_settings.py
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
# macOS filesystem
|
||||
.DS_Store
|
||||
|
||||
# Uncompressed training data
|
||||
data/strokes.ndjson
|
||||
nazoru_saved_model/
|
||||
nazoru.pb
|
||||
optimized_nazoru.pb
|
||||
256
mozc-nazoru/README.rst
Normal file
256
mozc-nazoru/README.rst
Normal file
@@ -0,0 +1,256 @@
|
||||
Gboard Physical Handwriting Version
|
||||
===================================
|
||||
|
||||
Gboard Physical Handwriting Version is a device which translates your
|
||||
scribble on your keyboard into a character. You can make your own Gboard
|
||||
Physical Handwriting Version by printing your own printed circuit board
|
||||
(PCB). Also, you can train your own model to recognize a customized set
|
||||
of characters. This repository provides circuit diagram, the board
|
||||
layout and software to recognize your stroke over the keyboard as a
|
||||
character.
|
||||
|
||||
Software Usage
|
||||
--------------
|
||||
|
||||
Input Characters
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ pip install .
|
||||
$ nazoru-input
|
||||
|
||||
By running the commands above, you can make your own machine into an
|
||||
input device which accepts scribbles on the connected keyboard and send
|
||||
characters via bluetooth. At the beginning, this script scans
|
||||
connected keyboards and starts listening to inputs from one of the
|
||||
keyboards. Then it translates a sequence of keydowns into a predicted
|
||||
character considering pressed timings, and send the character to the
|
||||
target device paired by bluetooth.
|
||||
|
||||
If you want to try it for development, you can use ``-e`` option.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ pip install -e .
|
||||
$ nazoru-input
|
||||
|
||||
Training Model
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ pip install .
|
||||
$ nazoru-training ./data/strokes.zip
|
||||
|
||||
We have a script to generate a trained model which recognizes input
|
||||
characters from scribbles. This script renders input stroke data into
|
||||
images to extract useful features for prediction considering position of
|
||||
the key and timing of keyboard events. Rendered images are fed into the
|
||||
neural network model and the optimizer tunes the model to fit the data.
|
||||
Once the training is done, the script outputs the trained graph, which
|
||||
you can use for your own device. In the case where you install
|
||||
``nazoru-training`` from pip, you can find ``strokes.zip`` at here:
|
||||
https://github.com/google/mozc-devices/mozc-nazoru/data/strokes.zip
|
||||
|
||||
You can change some configurations by passing command line flags (e.g.
|
||||
path to the input/output files, hyper-parameters). Run
|
||||
``nazoru-training --help`` for details.
|
||||
|
||||
Hardware Setup
|
||||
--------------
|
||||
|
||||
Printed Circuit Board
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
Gboard Physical Handwriting Version uses Raspberry Pi Zero for the
|
||||
keyboard input recognition and RN42 module for Bluetooth connection to
|
||||
your laptop. You can check the wiring at ``board/schematic.png``. Also,
|
||||
the original CAD data in EAGLE format is available
|
||||
(``board/nazoru-stack.sch`` and ``board/nazoru-stack.brd``). The board
|
||||
has non-connected pads and connectors for SPI and I2C. The connector
|
||||
itself should be compatible with other Raspberry Pi, but we tested it
|
||||
only on Raspberry Pi Zero W.
|
||||
|
||||
.. image:: ./board/schematic.png
|
||||
:width: 1000px
|
||||
|
||||
Raspberry Pi Setup
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
**Step 0 - Prepare your Raspberry Pi**
|
||||
|
||||
Please prepare your Raspberry Pi, SD card initialized by RASPBIAN
|
||||
image, and RN42 module. Connect your Raspberry Pi with RN42 as the
|
||||
schematic shows. Please make sure you can have access to the internet
|
||||
and also it has enough disk space to install packages on the following
|
||||
steps.
|
||||
|
||||
**Step 1 - Setup UART to RN42**
|
||||
|
||||
If you try it on Raspberry Pi Zero W or Raspberry Pi 3, you need to
|
||||
have additional settings for the serial communication because they
|
||||
equipped a wireless module connected by the UART. See details at `an
|
||||
official document
|
||||
<https://www.raspberrypi.org/documentation/configuration/uart.md>`_.
|
||||
In short, you need to add ``enable_uart=1`` to ``/boot/config.txt`` on
|
||||
your Raspberry Pi.
|
||||
|
||||
**Step 2 - Initial setup for RN42**
|
||||
|
||||
You need to write your initial setup to RN42. At first, install screen
|
||||
and open ``/dev/serial0`` for configuration.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ sudo apt install screen
|
||||
$ sudo screen /dev/serial0 115200
|
||||
|
||||
After that, please type the following commands. Note that you need to
|
||||
type ENTER after input commands. For example, please type ``$$$``
|
||||
and ENTER to execute ``$$$`` command.
|
||||
|
||||
1. ``$$$`` : Get into the command mode. The green LED will blink
|
||||
faster.
|
||||
2. ``+`` : You can see what you type.
|
||||
3. ``SD,0540`` : Set the device class to keyboard.
|
||||
4. ``S~,6`` : Set the profile to HID.
|
||||
5. ``SH,0200`` : Set the HID flag to keyboard.
|
||||
6. ``SN,nazoru-input`` : Set the device name as nazoru-input. You
|
||||
can name it as you want.
|
||||
7. ``R,1`` : Reboot RN42.
|
||||
|
||||
You can quit the screen by ``C-a k``.
|
||||
|
||||
**Step 3 - Download and install nazoru-input**
|
||||
|
||||
We provide a service file at ``data/nazoru.service`` to launch
|
||||
``nazoru-input`` when booting. You can install it by uncomment
|
||||
``data_files`` entry in ``setup.py``. Also, before installing this
|
||||
package, We'd strongly recommend you to install some package from apt
|
||||
repository as follows, so that you can install pre-built packages.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ sudo apt install git python-pip python-numpy python-cairocffi \
|
||||
python-h5py python-imaging python-scipy libblas-dev liblapack-dev \
|
||||
python-dev libatlas-base-dev gfortran python-setuptools \
|
||||
python-html5lib
|
||||
$ sudo pip install http://ci.tensorflow.org/view/Nightly/job/nightly-pi-zero/219/artifact/output-artifacts/tensorflow-1.6.0-cp27-none-any.whl
|
||||
$ git clone https://github.com/google/mozc-devices
|
||||
$ cd mozc-devices/mozc-nazoru
|
||||
$ sudo pip install . # If you want to develop nazoru-input, please use 'pip install -e .' instead.
|
||||
|
||||
**Step 4 - Enjoy!**
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ sudo nazoru-input # If you miss sudo, nazoru-input may use a DummyBluetooth object.
|
||||
|
||||
Training Data Format
|
||||
--------------------
|
||||
|
||||
We are providing the raw training data at ``data/strokes.zip``. Once you
|
||||
uncompress the zip file, you will get a ``.ndjson`` file which contains
|
||||
all entries (we call them **strokes**) we have used for training.
|
||||
|
||||
Each stroke entry contains the following field:
|
||||
|
||||
+----------+-----------+-------------------------------------------+
|
||||
| Key | Type | Description |
|
||||
+==========+===========+===========================================+
|
||||
| id | integer | A unique identifier across all strokes. |
|
||||
+----------+-----------+-------------------------------------------+
|
||||
| writer | string | A unique identifier of writer. |
|
||||
+----------+-----------+-------------------------------------------+
|
||||
| kana | string | Label of the character drawn. |
|
||||
+----------+-----------+-------------------------------------------+
|
||||
| events | list | List of keyboard events. |
|
||||
+----------+-----------+-------------------------------------------+
|
||||
|
||||
Each event is a 3-tuple of (``key``, ``event type``, ``time``). ``key``
|
||||
describes the key on which the event happened. ``event type`` describes
|
||||
what type of event happened. It should be "down" (keydown) or "up"
|
||||
(keyup). ``time`` describes the consumed time until the event is fired
|
||||
in millisecond.
|
||||
|
||||
For example, the entry below denotes a stoke of "ほ
|
||||
(\\u307b)" accompanied with a sequence of keyboard events
|
||||
starting from the keydown event on "t" and ending at the keyup event on
|
||||
"l" which was fired 1.005 seconds later after it started recording.
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"id": 5788999721418752,
|
||||
"writer": "ffb0dac6b8be3faa81da320e29a2ba72",
|
||||
"kana": "\u307b",
|
||||
"events": [
|
||||
["t", "down", 0],
|
||||
["g", "down", 40],
|
||||
...
|
||||
["l", "down", 966],
|
||||
["l", "up", 1005]
|
||||
]
|
||||
}
|
||||
|
||||
You can also prepare your own dataset in ``.ndjson`` format and rebuild
|
||||
the model on it. The list of kanas to recognize is in
|
||||
``src/nazoru/lib.py``. You can update that if you want to modify the set
|
||||
of characters.
|
||||
|
||||
Network Structure
|
||||
-----------------
|
||||
|
||||
Data Preprocessing
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Each stroke entry is rendered to a square image before any training
|
||||
runs. The script (``nazoru-training``) renders strokes in various ways
|
||||
to extract useful features. Our default settings extract 10 features
|
||||
from each stroke entry: 8 directional features and 2 temporal features
|
||||
on 16x16 square canvas; this means that the input shape is 16x16x10 by
|
||||
default.
|
||||
|
||||
Convolutional Network
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Rendered inputs are fed into a convolutional neural network designed for
|
||||
this task. Body structure looks like:
|
||||
|
||||
- Convolutional layer (kernel size: 3x3, filter size: 32, stride: 2,
|
||||
activation: Relu)
|
||||
- Separatable convolutional layer (kernel size: 3x3, filter size: 64,
|
||||
stride: 1, activation: Relu)
|
||||
- Separatable convolutional layer (kernel size: 3x3, filter size: 128,
|
||||
stride: 2, activation: Relu)
|
||||
- Separatable convolutional layer (kernel size: 3x3, filter size: 128,
|
||||
stride: 1, activation: Relu)
|
||||
|
||||
For more details about the separatable convolutional layers, please
|
||||
refer to `MobileNet <https://arxiv.org/abs/1704.04861>`__ architecture.
|
||||
|
||||
Authors
|
||||
-------
|
||||
|
||||
Machine Learning:
|
||||
|
||||
Shuhei Iitsuka <tushuhei@google.com>
|
||||
|
||||
Hardwares, system setups:
|
||||
|
||||
Makoto Shimazu <shimazu@google.com>
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
not use this file except in compliance with the License. You may obtain
|
||||
a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
77
mozc-nazoru/bin/nazoru-input
Executable file
77
mozc-nazoru/bin/nazoru-input
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
from nazoru.led import LED_BLUE, LED_RED
|
||||
LED_RED.blink(1)
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from nazoru import get_default_graph_path
|
||||
from nazoru.core import create_keyboard_recorder, Bluetooth, NazoruPredictor
|
||||
|
||||
def main():
|
||||
FLAGS = None
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-g', '--graph',
|
||||
type=str,
|
||||
default=get_default_graph_path(),
|
||||
help='Path to a trained model which is generated by ' +
|
||||
'nazoru-training.')
|
||||
parser.add_argument('-v', '--verbose', action='store_true')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
|
||||
LED_RED.blink(0.3)
|
||||
|
||||
bt_connection = Bluetooth()
|
||||
try:
|
||||
recorder = create_keyboard_recorder(verbose=FLAGS.verbose)
|
||||
except IOError as e:
|
||||
LED_RED.off()
|
||||
LED_BLUE.off()
|
||||
raise e
|
||||
predictor = NazoruPredictor(FLAGS.graph)
|
||||
|
||||
LED_RED.off()
|
||||
LED_BLUE.blink(1)
|
||||
|
||||
print('Ready. Please input your scrrible.')
|
||||
while True:
|
||||
data, command = recorder.record()
|
||||
if command is not None:
|
||||
print('command: %s' % command)
|
||||
bt_connection.command(command)
|
||||
continue
|
||||
if data is None:
|
||||
print('done.')
|
||||
break
|
||||
|
||||
LED_RED.on()
|
||||
result = predictor.predict_top_n(data, 5)
|
||||
LED_RED.off()
|
||||
|
||||
print('\n=== RESULTS ===')
|
||||
for item in result:
|
||||
print(u'%s (%s): %.5f' % (item[0], item[1], item[2]))
|
||||
print('===============\n')
|
||||
|
||||
most_likely_result = result[0]
|
||||
print(u'%s (%s)' % (most_likely_result[0], most_likely_result[1]))
|
||||
bt_connection.send(most_likely_result[1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
279
mozc-nazoru/bin/nazoru-training
Executable file
279
mozc-nazoru/bin/nazoru-training
Executable file
@@ -0,0 +1,279 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tensorflow.contrib import slim
|
||||
from tensorflow.contrib.slim.python.slim.learning import train_step
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.tools import optimize_for_inference_lib
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import nazoru.core as nazoru
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
import zipfile
|
||||
|
||||
FLAGS = None
|
||||
|
||||
def load_data(kanas):
|
||||
|
||||
def get_ndjson_path(zip_path):
|
||||
dir_path, filename = os.path.split(zip_path)
|
||||
body, ext = os.path.splitext(filename)
|
||||
return dir_path, os.path.join(dir_path, '%s.ndjson' % (body))
|
||||
|
||||
data = []
|
||||
labels = []
|
||||
dir_path, ndjson_path = get_ndjson_path(FLAGS.stroke_data)
|
||||
if not os.path.exists(ndjson_path):
|
||||
with zipfile.ZipFile(FLAGS.stroke_data, 'r') as f:
|
||||
f.extractall(dir_path)
|
||||
with open(ndjson_path) as f:
|
||||
for line in f.readlines():
|
||||
line = json.loads(line)
|
||||
if not line['kana'].lower() in kanas: continue
|
||||
keydowns = [(e[0], e[2]) for e in line['events'] if e[1] == 'down']
|
||||
# TODO (tushuhei) Relax this condition to accept shorter strokes.
|
||||
if len(keydowns) < 5: continue # Ignore if too short.
|
||||
data.append(keydowns)
|
||||
labels.append(kanas.index(line['kana'].lower()))
|
||||
labels = np.eye(len(kanas))[labels] # Idiom to one-hot encode.
|
||||
return data, labels
|
||||
|
||||
|
||||
def convert_data_to_tensors(x, y):
|
||||
inputs = tf.constant(x, dtype=tf.float32)
|
||||
inputs.set_shape([None] + list(x[0].shape))
|
||||
outputs = tf.constant(y)
|
||||
outputs.set_shape([None] + list(y[0].shape))
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def save_inference_graph(checkpoint_dir, input_shape, num_classes, conv_defs,
|
||||
output_graph, optimized_output_graph,
|
||||
input_node_name=nazoru.lib.INPUT_NODE_NAME,
|
||||
output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
|
||||
with tf.Graph().as_default():
|
||||
inputs = tf.placeholder(
|
||||
tf.float32, shape=[1] + list(input_shape), name=input_node_name)
|
||||
outputs, _ = nazoru.nazorunet(
|
||||
inputs, num_classes=num_classes, conv_defs=conv_defs, is_training=False,
|
||||
dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE)
|
||||
sv = tf.train.Supervisor(logdir=checkpoint_dir)
|
||||
with sv.managed_session() as sess:
|
||||
output_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, sess.graph.as_graph_def(), [output_node_name])
|
||||
# TODO (tushuhei) Maybe we don't need to export unoptimized graph.
|
||||
with gfile.FastGFile(output_graph, 'wb') as f:
|
||||
f.write(output_graph_def.SerializeToString())
|
||||
|
||||
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
|
||||
output_graph_def, [input_node_name], [output_node_name],
|
||||
dtypes.float32.as_datatype_enum)
|
||||
|
||||
with gfile.FastGFile(optimized_output_graph, 'wb') as f:
|
||||
f.write(output_graph_def.SerializeToString())
|
||||
|
||||
def export_saved_model_from_pb(saved_model_dir, graph_name,
|
||||
input_node_name=nazoru.lib.INPUT_NODE_NAME,
|
||||
output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
|
||||
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
|
||||
with tf.gfile.GFile(graph_name, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
sigs = {}
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
g = tf.get_default_graph()
|
||||
|
||||
sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
|
||||
tf.saved_model.signature_def_utils.predict_signature_def(
|
||||
{input_node_name: g.get_tensor_by_name('%s:0' % (input_node_name))},
|
||||
{output_node_name: g.get_tensor_by_name('%s:0' % (output_node_name))
|
||||
})
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, [tag_constants.SERVING], signature_def_map=sigs)
|
||||
|
||||
builder.save()
|
||||
|
||||
|
||||
def main(_):
|
||||
kanas = nazoru.lib.KANAS
|
||||
keydown_strokes, labels = load_data(kanas)
|
||||
x = np.array([
|
||||
nazoru.lib.keydowns2image(
|
||||
keydown_stroke,
|
||||
directional_feature=not FLAGS.no_directional_feature,
|
||||
temporal_feature=not FLAGS.no_temporal_feature,
|
||||
scale=FLAGS.image_size,
|
||||
stroke_width=FLAGS.stroke_width)
|
||||
for keydown_stroke in keydown_strokes])
|
||||
train_x, train_t, val_x, val_t, test_x, test_t = nazoru.lib.split_data(
|
||||
x, labels, 0.2, 0.2)
|
||||
conv_defs = [
|
||||
nazoru.Conv(kernel=[3, 3], stride=2, depth=32),
|
||||
nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
|
||||
nazoru.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
|
||||
nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
|
||||
]
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
train_x, train_t = convert_data_to_tensors(train_x, train_t)
|
||||
val_x, val_t = convert_data_to_tensors(val_x, val_t)
|
||||
|
||||
# Make the model.
|
||||
train_logits, train_endpoints = nazoru.nazorunet(
|
||||
train_x, num_classes=len(kanas), conv_defs=conv_defs,
|
||||
dropout_keep_prob=FLAGS.dropout_keep_prob, scope=nazoru.lib.SCOPE,
|
||||
reuse=tf.AUTO_REUSE, is_training=True)
|
||||
val_logits, val_endpoints = nazoru.nazorunet(
|
||||
val_x, num_classes=len(kanas), conv_defs=conv_defs,
|
||||
dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE,
|
||||
is_training=True)
|
||||
|
||||
# Add the loss function to the graph.
|
||||
tf.losses.softmax_cross_entropy(train_t, train_logits)
|
||||
train_total_loss = tf.losses.get_total_loss()
|
||||
tf.summary.scalar('train_total_loss', train_total_loss)
|
||||
|
||||
val_total_loss = tf.losses.get_total_loss()
|
||||
tf.summary.scalar('val_total_loss', val_total_loss)
|
||||
|
||||
train_accuracy = tf.reduce_mean(tf.cast(tf.equal(
|
||||
tf.argmax(train_logits, axis=1), tf.argmax(train_t, axis=1)), tf.float32))
|
||||
tf.summary.scalar('train_accuracy', train_accuracy)
|
||||
|
||||
val_accuracy = tf.reduce_mean(tf.cast(tf.equal(
|
||||
tf.argmax(val_logits, axis=1), tf.argmax(val_t, axis=1)), tf.float32))
|
||||
tf.summary.scalar('val_accuracy', val_accuracy)
|
||||
|
||||
# Specify the optimizer and create the train op:
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=0.005)
|
||||
train_op = slim.learning.create_train_op(train_total_loss, optimizer)
|
||||
|
||||
def train_step_fn(sess, *args, **kwargs):
|
||||
total_loss, should_stop = train_step(sess, *args, **kwargs)
|
||||
if train_step_fn.step % FLAGS.n_steps_to_log == 0:
|
||||
val_acc = sess.run(val_accuracy)
|
||||
tf_logging.info('step: %d, validation accuracy: %.3f' % (
|
||||
train_step_fn.step, val_acc))
|
||||
train_step_fn.step += 1
|
||||
return [total_loss, should_stop]
|
||||
train_step_fn.step = 0
|
||||
|
||||
# Run the training inside a session.
|
||||
final_loss = slim.learning.train(
|
||||
train_op,
|
||||
logdir=FLAGS.checkpoint_dir,
|
||||
number_of_steps=FLAGS.n_train_steps,
|
||||
train_step_fn=train_step_fn,
|
||||
save_summaries_secs=5,
|
||||
log_every_n_steps=FLAGS.n_steps_to_log)
|
||||
|
||||
save_inference_graph(checkpoint_dir=FLAGS.checkpoint_dir,
|
||||
input_shape=train_x.shape[1:], num_classes=len(kanas),
|
||||
conv_defs=conv_defs, output_graph=FLAGS.output_graph,
|
||||
optimized_output_graph=FLAGS.optimized_output_graph)
|
||||
export_saved_model_from_pb(saved_model_dir=FLAGS.saved_model_dir,
|
||||
graph_name=FLAGS.optimized_output_graph)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--image_size',
|
||||
type=int,
|
||||
default=16,
|
||||
help='Size of the square canvas to render input strokes.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--stroke_width',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Stroke width of rendered strokes.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--n_train_steps',
|
||||
type=int,
|
||||
default=500,
|
||||
help='Number of training steps to run.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--n_steps_to_log',
|
||||
type=int,
|
||||
default=10,
|
||||
help='How often to print log during traing.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--checkpoint_dir',
|
||||
type=str,
|
||||
default=os.path.join(os.sep, 'tmp', 'nazorunet'),
|
||||
help='Where to save checkpoint files.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_graph',
|
||||
type=str,
|
||||
default='nazoru.pb',
|
||||
help='Where to save the trained graph.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--optimized_output_graph',
|
||||
type=str,
|
||||
default='optimized_nazoru.pb',
|
||||
help='Where to save the trained graph optimized for inference.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--saved_model_dir',
|
||||
type=str,
|
||||
default='nazoru_saved_model',
|
||||
help='Where to save the exported graph.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dropout_keep_prob',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='The percentage of activation values that are retained in dropout.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'stroke_data',
|
||||
type=str,
|
||||
help='Path to zipped stroke data to input. You can download the ' +
|
||||
'default stroke data at ' +
|
||||
'https://github.com/google/mozc-devices/mozc-nazoru/data/strokes.zip.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no_directional_feature',
|
||||
action='store_false',
|
||||
help='Not to use directional feature.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no_temporal_feature',
|
||||
action='store_false',
|
||||
help='Not to use temporal feature.'
|
||||
)
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
1111
mozc-nazoru/board/nazoru-stack.brd
Normal file
1111
mozc-nazoru/board/nazoru-stack.brd
Normal file
File diff suppressed because it is too large
Load Diff
2109
mozc-nazoru/board/nazoru-stack.sch
Normal file
2109
mozc-nazoru/board/nazoru-stack.sch
Normal file
File diff suppressed because it is too large
Load Diff
BIN
mozc-nazoru/board/schematic.png
Normal file
BIN
mozc-nazoru/board/schematic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 215 KiB |
15
mozc-nazoru/data/nazoru.service
Normal file
15
mozc-nazoru/data/nazoru.service
Normal file
@@ -0,0 +1,15 @@
|
||||
[Unit]
|
||||
Description=nazoru-input daemon
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
Environment=PYTHONIOENCODING=utf-8
|
||||
ExecStart=/usr/bin/env nazoru-input
|
||||
StandardOutput=syslog
|
||||
StandardError=syslog
|
||||
SyslogIdentifier=nazoru-input
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
BIN
mozc-nazoru/data/strokes.zip
Normal file
BIN
mozc-nazoru/data/strokes.zip
Normal file
Binary file not shown.
70
mozc-nazoru/setup.py
Executable file
70
mozc-nazoru/setup.py
Executable file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
def read_file(name):
|
||||
with open(os.path.join(os.path.dirname(__file__), name), 'r') as f:
|
||||
return f.read().strip()
|
||||
|
||||
setup(
|
||||
name='nazoru-input',
|
||||
version='0.1',
|
||||
author='Makoto Shimazu',
|
||||
author_email='shimazu@google.com',
|
||||
url='https://landing.google.com/tegaki',
|
||||
description='Package for Gboard Physical Handwriting Version',
|
||||
long_description=read_file('README.rst'),
|
||||
license='Apache',
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Environment :: Console',
|
||||
'Environment :: No Input/Output (Daemon)',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
'Topic :: Utilities',
|
||||
],
|
||||
packages=find_packages('src'),
|
||||
package_dir={'': 'src'},
|
||||
package_data={
|
||||
'nazoru': ['data/optimized_nazoru.pb']
|
||||
},
|
||||
scripts=[
|
||||
'bin/nazoru-input',
|
||||
'bin/nazoru-training'
|
||||
],
|
||||
|
||||
# For installing the nazoru_input as a service of systemd. Please uncomment
|
||||
# the following |data_files| if you want to install nazoru.service.
|
||||
# data_files=[('/etc/systemd/system', ['data/nazoru.service'])],
|
||||
|
||||
install_requires=[
|
||||
'np_utils',
|
||||
'cairocffi',
|
||||
'numpy',
|
||||
'h5py',
|
||||
'pillow',
|
||||
'tensorflow',
|
||||
'enum34;python_version<"3.4"',
|
||||
'pyserial',
|
||||
'evdev;platform_system=="Linux"'
|
||||
]
|
||||
)
|
||||
6
mozc-nazoru/src/nazoru/__init__.py
Normal file
6
mozc-nazoru/src/nazoru/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
def get_default_graph_path():
|
||||
import os
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
return os.path.join(script_dir, 'data', 'optimized_nazoru.pb')
|
||||
57
mozc-nazoru/src/nazoru/bluetooth.py
Normal file
57
mozc-nazoru/src/nazoru/bluetooth.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import serial
|
||||
import struct
|
||||
|
||||
class Bluetooth():
|
||||
DEVICE_FILE = '/dev/serial0'
|
||||
BAUDRATE = 115200
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self._conn = serial.Serial(self.DEVICE_FILE, self.BAUDRATE)
|
||||
self._dummy = False
|
||||
print('Bluetooth')
|
||||
except serial.SerialException:
|
||||
self._conn = None
|
||||
self._dummy = True
|
||||
print('Dummy Bluetooth')
|
||||
|
||||
def send(self, string):
|
||||
"""Send |string| as a series of characters. |string| should be
|
||||
alphabets, numbers and symbols which can be typed from your keyboard."""
|
||||
if self._dummy:
|
||||
print('bluetooth: {}'.format(string))
|
||||
return
|
||||
self._conn.write(string)
|
||||
|
||||
# See http://ww1.microchip.com/downloads/en/DeviceDoc/bluetooth_cr_UG-v1.0r.pdf
|
||||
# for detail.
|
||||
UART_CODES = {
|
||||
'KEY_RIGHT': 7,
|
||||
'KEY_BACKSPACE': 8,
|
||||
'KEY_ENTER': 10,
|
||||
'KEY_LEFT': 11,
|
||||
'KEY_DOWN': 12,
|
||||
'KEY_UP': 14,
|
||||
}
|
||||
|
||||
def command(self, cmd):
|
||||
if cmd not in self.UART_CODES:
|
||||
print('Unknown Command: {}'.format(cmd))
|
||||
return
|
||||
self.send(struct.pack('b', self.UART_CODES[cmd]))
|
||||
26
mozc-nazoru/src/nazoru/core.py
Normal file
26
mozc-nazoru/src/nazoru/core.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from . import lib
|
||||
from .bluetooth import Bluetooth
|
||||
from .keyboard_recorder import create_keyboard_recorder
|
||||
from .nazorunet import nazorunet
|
||||
from .nazorunet import Conv
|
||||
from .nazorunet import DepthSepConv
|
||||
from .predictor import NazoruPredictor
|
||||
|
||||
BIN
mozc-nazoru/src/nazoru/data/optimized_nazoru.pb
Normal file
BIN
mozc-nazoru/src/nazoru/data/optimized_nazoru.pb
Normal file
Binary file not shown.
254
mozc-nazoru/src/nazoru/keyboard_recorder.py
Normal file
254
mozc-nazoru/src/nazoru/keyboard_recorder.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from enum import Enum
|
||||
import datetime
|
||||
import fcntl
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import termios
|
||||
|
||||
try:
|
||||
import evdev
|
||||
except ImportError:
|
||||
evdev = None
|
||||
|
||||
def _set_raw_mode_stdin():
|
||||
fno = sys.stdin.fileno()
|
||||
attr_old = termios.tcgetattr(fno)
|
||||
fcntl_old = fcntl.fcntl(fno, fcntl.F_GETFL)
|
||||
|
||||
attr_new = termios.tcgetattr(fno)
|
||||
attr_new[3] = attr_new[3] & ~termios.ECHO & ~termios.ICANON
|
||||
termios.tcsetattr(fno, termios.TCSADRAIN, attr_new)
|
||||
|
||||
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old | os.O_NONBLOCK)
|
||||
|
||||
def reset_raw_mode():
|
||||
termios.tcsetattr(fno, termios.TCSANOW, attr_old)
|
||||
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old)
|
||||
|
||||
return reset_raw_mode
|
||||
|
||||
def _set_raw_mode_general(f):
|
||||
fno = f.fileno()
|
||||
fcntl_old = fcntl.fcntl(fno, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old | os.O_NONBLOCK)
|
||||
|
||||
def reset_raw_mode():
|
||||
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old)
|
||||
|
||||
return reset_raw_mode
|
||||
|
||||
class KeyboardRecorder():
|
||||
def __init__(self, verbose=False):
|
||||
self._verbose = verbose
|
||||
|
||||
def record(self):
|
||||
return (None, None)
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
# TODO(shimazu): Use logger.
|
||||
if self._verbose:
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
class KeyboardRecorderFromConsole(KeyboardRecorder):
|
||||
def __init__(self, verbose=False):
|
||||
KeyboardRecorder.__init__(self, verbose)
|
||||
self.log('Input from console')
|
||||
|
||||
def record(self):
|
||||
"""
|
||||
Returns a tuple of |data| and |command|.
|
||||
|data|: an array of tuples of keys and time from the first character.
|
||||
|command|: None
|
||||
"""
|
||||
recording = False
|
||||
start_time = None
|
||||
last_time = None
|
||||
wait_seconds = 2
|
||||
data = []
|
||||
reset_raw_mode = _set_raw_mode_stdin()
|
||||
try:
|
||||
while 1:
|
||||
try:
|
||||
key = sys.stdin.read(1)
|
||||
except IOError:
|
||||
key = None
|
||||
finally:
|
||||
now = datetime.datetime.now()
|
||||
if key == '\n':
|
||||
return (None, None)
|
||||
elif key:
|
||||
if not recording:
|
||||
recording = True
|
||||
start_time = datetime.datetime.now()
|
||||
elapsed_time = now - start_time
|
||||
elapsed_ms = int(elapsed_time.total_seconds() * 1000)
|
||||
last_time = now
|
||||
data.append((key, elapsed_ms))
|
||||
self.log(key, elapsed_ms)
|
||||
if last_time and (now - last_time).total_seconds() > wait_seconds:
|
||||
break
|
||||
finally:
|
||||
reset_raw_mode()
|
||||
return (data, None)
|
||||
|
||||
class KeyboardRecorderFromEvdev(KeyboardRecorder):
|
||||
KEYS = {
|
||||
2: '1',
|
||||
3: '2',
|
||||
4: '3',
|
||||
5: '4',
|
||||
6: '5',
|
||||
7: '6',
|
||||
8: '7',
|
||||
9: '8',
|
||||
10: '9',
|
||||
11: '0',
|
||||
12: '-',
|
||||
13: '=',
|
||||
16: 'q',
|
||||
17: 'w',
|
||||
18: 'e',
|
||||
19: 'r',
|
||||
20: 't',
|
||||
21: 'y',
|
||||
22: 'u',
|
||||
23: 'i',
|
||||
24: 'o',
|
||||
25: 'p',
|
||||
26: '[',
|
||||
27: ']',
|
||||
30: 'a',
|
||||
31: 's',
|
||||
32: 'd',
|
||||
33: 'f',
|
||||
34: 'g',
|
||||
35: 'h',
|
||||
36: 'j',
|
||||
37: 'k',
|
||||
38: 'l',
|
||||
39: ';',
|
||||
40: '\'',
|
||||
43: '\\',
|
||||
44: 'z',
|
||||
45: 'x',
|
||||
46: 'c',
|
||||
47: 'v',
|
||||
48: 'b',
|
||||
49: 'n',
|
||||
50: 'm',
|
||||
51: ',',
|
||||
52: '.',
|
||||
53: '/'
|
||||
}
|
||||
WAIT_SECONDS = 2
|
||||
|
||||
def __init__(self, verbose=False):
|
||||
if evdev is None:
|
||||
raise TypeError('KeyboardRecorderFromEvdev needs to be used on Linux ' +
|
||||
'(or POSIX compatible) system. Instead, You can try it ' +
|
||||
'on your console.')
|
||||
KeyboardRecorder.__init__(self, verbose)
|
||||
self.log('Input from evdev')
|
||||
keyboards = []
|
||||
ecode_ev_key = evdev.ecodes.ecodes['EV_KEY']
|
||||
ecode_key_esc = evdev.ecodes.ecodes['KEY_ESC']
|
||||
for device in [evdev.InputDevice(fn) for fn in evdev.list_devices()]:
|
||||
# TODO(shimazu): Consider more solid criteria to get 'regular' keyboards.
|
||||
if ecode_ev_key in device.capabilities() and \
|
||||
ecode_key_esc in device.capabilities()[ecode_ev_key]:
|
||||
keyboards.append(device)
|
||||
if len(keyboards) == 0:
|
||||
raise IOError('No keyboard found.')
|
||||
self._keyboards = keyboards
|
||||
for keyboard in keyboards:
|
||||
self.log('----')
|
||||
self.log(keyboard)
|
||||
self.log('name: {0}'.format(keyboard.name))
|
||||
self.log('phys: {0}'.format(keyboard.phys))
|
||||
self.log('repeat: {0}'.format(keyboard.repeat))
|
||||
self.log('info: {0}'.format(keyboard.info))
|
||||
self.log(keyboard.capabilities(verbose=True))
|
||||
|
||||
def record(self):
|
||||
"""
|
||||
Returns a tuple of |data| and |command|.
|
||||
|data|: an array of tuples of keys and time from the first character. None
|
||||
if the input is non-alphabet/numeric/symbols like ENTER, arrows etc.
|
||||
|command|: Commands like "KEY_ENTER" or None if |data| is valid.
|
||||
"""
|
||||
start_time = None
|
||||
last_time = None
|
||||
data = []
|
||||
while True:
|
||||
# TODO(shimazu): Check inputs from all keyboards.
|
||||
event = self._keyboards[0].read_one()
|
||||
now = datetime.datetime.now()
|
||||
if last_time and (now - last_time).total_seconds() > self.WAIT_SECONDS:
|
||||
break
|
||||
if event is None:
|
||||
continue
|
||||
name = evdev.ecodes.bytype[event.type][event.code]
|
||||
ev_type = evdev.ecodes.EV[event.type]
|
||||
if ev_type != 'EV_KEY':
|
||||
continue
|
||||
# Keyboard input
|
||||
self.log('----')
|
||||
self.log(event)
|
||||
self.log('name: {}'.format(name))
|
||||
self.log('type: {}'.format(ev_type))
|
||||
# Check if the event is from releasing the button
|
||||
if event.value == 0:
|
||||
continue
|
||||
|
||||
# It may be a non-alphabet/numeric/symbol key. Return it as a command.
|
||||
if event.code not in self.KEYS:
|
||||
if start_time is not None:
|
||||
continue
|
||||
return (None, name)
|
||||
last_time = now
|
||||
if start_time is None:
|
||||
start_time = now
|
||||
elapsed_ms = int((now - start_time).total_seconds() * 1000)
|
||||
data.append((self.KEYS[event.code], elapsed_ms))
|
||||
return (data, None)
|
||||
|
||||
InputSource = Enum('InputSource', 'EVDEV CONSOLE')
|
||||
|
||||
def create_keyboard_recorder(verbose=False, source=None):
|
||||
"""Creates KeyboardRecorder.
|
||||
|
||||
Args:
|
||||
verbose: Print the detail of input when it's true.
|
||||
source: InputSource.EVDEV, InputSource.CONSOLE or None (default)
|
||||
|
||||
Returns:
|
||||
recorder: Corresponding KeyboardRecorder. If |source| is None, returns
|
||||
KeyboardRecorderFromConsole when stdin is attached to a console (isatty is
|
||||
true). Otherwise, returns KeyboardRecorderFromEvdev.
|
||||
"""
|
||||
if source == InputSource.CONSOLE:
|
||||
return KeyboardRecorderFromConsole(verbose=verbose)
|
||||
if source == InputSource.EVDEV:
|
||||
return KeyboardRecorderFromEvdev(verbose=verbose)
|
||||
if sys.__stdin__.isatty():
|
||||
return KeyboardRecorderFromConsole(verbose=verbose)
|
||||
return KeyboardRecorderFromEvdev(verbose=verbose)
|
||||
82
mozc-nazoru/src/nazoru/led.py
Normal file
82
mozc-nazoru/src/nazoru/led.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
class LEDBase():
|
||||
def __init__(self, pin):
|
||||
self._pin = pin
|
||||
def on(self):
|
||||
pass
|
||||
def off(self):
|
||||
pass
|
||||
def blink(self, interval):
|
||||
pass
|
||||
|
||||
try:
|
||||
import RPi.GPIO as GPIO
|
||||
import threading
|
||||
|
||||
class LED(LEDBase):
|
||||
ON = 'ON'
|
||||
OFF = 'OFF'
|
||||
|
||||
def __init__(self, pin):
|
||||
GPIO.setmode(GPIO.BOARD)
|
||||
self._pin = pin
|
||||
self._lock = threading.Lock()
|
||||
self._timer = None
|
||||
GPIO.setup(pin, GPIO.OUT)
|
||||
self.off()
|
||||
|
||||
def on(self):
|
||||
with self._lock:
|
||||
self._state = self.ON
|
||||
GPIO.output(self._pin, False)
|
||||
self._ensure_stop_timer()
|
||||
|
||||
def off(self):
|
||||
with self._lock:
|
||||
self._state = self.OFF
|
||||
GPIO.output(self._pin, True)
|
||||
self._ensure_stop_timer()
|
||||
|
||||
def blink(self, interval):
|
||||
self._ensure_stop_timer()
|
||||
def toggle():
|
||||
self._timer = None
|
||||
if self._state == self.ON:
|
||||
self.off()
|
||||
else:
|
||||
self.on()
|
||||
self._timer = threading.Timer(interval, toggle)
|
||||
self._timer.daemon = True
|
||||
self._timer.start()
|
||||
toggle()
|
||||
|
||||
def _ensure_stop_timer(self):
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
class LED(LEDBase):
|
||||
pass
|
||||
|
||||
LED_BLUE = LED(38)
|
||||
LED_RED = LED(40)
|
||||
410
mozc-nazoru/src/nazoru/lib.py
Normal file
410
mozc-nazoru/src/nazoru/lib.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Nazoru input library.
|
||||
This is a collection of methods to preprocess input stroke data before any
|
||||
training starts.
|
||||
"""
|
||||
|
||||
import time
|
||||
import random
|
||||
import cairocffi as cairo
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
SCOPE = 'Nazorunet'
|
||||
INPUT_NODE_NAME = 'inputs'
|
||||
OUTPUT_NODE_NAME = SCOPE + '/Predictions/Reshape_1'
|
||||
KANAS = (u'あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほ'
|
||||
u'まみむめもやゆよらりるれろわゐんゑを'
|
||||
u'abcdefghijklmnopqrstuvwxyz1234567890'
|
||||
u'♡ーずぐ')
|
||||
KEYS = ('a', 'i', 'u', 'e', 'o',
|
||||
'ka', 'ki', 'ku', 'ke', 'ko',
|
||||
'sa', 'si', 'su', 'se', 'so',
|
||||
'ta', 'ti', 'tu', 'te', 'to',
|
||||
'na', 'ni', 'nu', 'ne', 'no',
|
||||
'ha', 'hi', 'hu', 'he', 'ho',
|
||||
'ma', 'mi', 'mu', 'me', 'mo',
|
||||
'ya', 'yu', 'yo',
|
||||
'ra', 'ri', 'ru', 're', 'ro',
|
||||
'wa', 'wi', 'nn', 'we', 'wo',
|
||||
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
|
||||
'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
|
||||
'1', '2', '3', '4', '5', '6', '7', '8', '9', '0',
|
||||
'ha-to', '-', 'zu', 'gu')
|
||||
|
||||
class KeyboardArrangement(Enum):
|
||||
"""Enum for keyboard arrangements.
|
||||
"""
|
||||
qwerty_jis = [
|
||||
u'1234567890-^¥',
|
||||
u'qwertyuiop@[',
|
||||
u'asdfghjkl;:]',
|
||||
u'zxcvbnm,./_',
|
||||
]
|
||||
|
||||
|
||||
def key2pos(key, arrangement=KeyboardArrangement.qwerty_jis.value, offset=0.5):
|
||||
"""Returns the key position.
|
||||
|
||||
Args:
|
||||
key (string): Key to get position.
|
||||
arrangement (list): Keyboard arrangement.
|
||||
offset (number): How much the keys are shifting by row.
|
||||
|
||||
Returns:
|
||||
position (tuple(number, number)): Position (x, y).
|
||||
|
||||
"""
|
||||
|
||||
for i, row in enumerate(arrangement):
|
||||
if key in row:
|
||||
y = i
|
||||
x = row.index(key) + i * offset
|
||||
return (x, y)
|
||||
return None
|
||||
|
||||
|
||||
def keydowns2points(keydowns):
|
||||
"""Translates keydowns to points.
|
||||
|
||||
Args:
|
||||
keydowns: [(key, t), ...] List of keydowns.
|
||||
|
||||
Returns:
|
||||
points: [(x, y, t), ...] List of points.
|
||||
"""
|
||||
|
||||
points = []
|
||||
for keydown in keydowns:
|
||||
pos = key2pos(keydown[0])
|
||||
if pos:
|
||||
points.append((pos[0], pos[1], keydown[1]))
|
||||
return points
|
||||
|
||||
|
||||
def normalize_x(x):
|
||||
"""Normalizes position.
|
||||
|
||||
Args:
|
||||
x (list): [[x, y, t], ...] List of points to normalize the position (x, y)
|
||||
into 0-1 range.
|
||||
|
||||
Returns:
|
||||
x (list): [[x', y', t], ...] List of points with the normalized potision
|
||||
(x', y').
|
||||
"""
|
||||
|
||||
x = np.array(x)
|
||||
max_ = np.max(x[:, :2], axis=0)
|
||||
min_ = np.min(x[:, :2], axis=0)
|
||||
x[:, :2] = (x[:, :2] - min_) / (max_ - min_)
|
||||
return x
|
||||
|
||||
|
||||
def pendown_encode(x_diff, sigma=1.6):
|
||||
"""Encodes time difference into pendown state.
|
||||
|
||||
Args:
|
||||
x_diff (list): [[dx, dy, dt], ...] List of diffs to encode.
|
||||
|
||||
Returns:
|
||||
x_diff_encoded (list): [[dx, dy, dt, pendown], ...] Encoded list of diffs.
|
||||
"""
|
||||
|
||||
thres = np.mean(x_diff[:,2]) + sigma * np.std(x_diff[:,2])
|
||||
x_diff_encoded = np.concatenate((
|
||||
x_diff,
|
||||
[[0] if dt_i > thres else [1] for dt_i in x_diff[:, 2]]
|
||||
), axis=1)
|
||||
return x_diff_encoded
|
||||
|
||||
|
||||
def surface_to_array(surface):
|
||||
"""Returns image array from cairo surface.
|
||||
|
||||
Args:
|
||||
surface: Cairo surface to translate.
|
||||
"""
|
||||
buf = BytesIO()
|
||||
surface.write_to_png(buf)
|
||||
png_string = buf.getvalue()
|
||||
im = Image.open(BytesIO(png_string))
|
||||
imdata = np.asarray(im.convert('L'))
|
||||
return imdata
|
||||
|
||||
|
||||
def get_direction(diff):
|
||||
"""Returns directions and weights for 8-directional features.
|
||||
|
||||
For more detail, see
|
||||
|
||||
- Bai, Zhen-Long, and Qiang Huo. "A study on the use of 8-directional features
|
||||
for online handwritten Chinese character recognition."
|
||||
- Liu, Cheng-Lin, and Xiang-Dong Zhou. "Online Japanese character recognition
|
||||
using trajectory-based normalization and direction feature extraction."
|
||||
|
||||
Weight is halved for pen-up states.
|
||||
|
||||
Args:
|
||||
diff (numpy.array): Encoded diff vector (dx, dy, dt, pendown).
|
||||
|
||||
Returns:
|
||||
First direction (Right (0), Down, (2), Left (4), Up (6)) and its weight, and
|
||||
Second direction (Bottom right (1), Bottom left (3), Up left (5), Up right
|
||||
(7)) and its weight.
|
||||
|
||||
"""
|
||||
|
||||
if np.abs(diff[0]) >= np.abs(diff[1]):
|
||||
if diff[0] >= 0:
|
||||
direction1 = 0
|
||||
else:
|
||||
direction1 = 4
|
||||
else:
|
||||
if diff[1] >= 0:
|
||||
direction1 = 2
|
||||
else:
|
||||
direction1 = 6
|
||||
|
||||
if diff[0] >= 0:
|
||||
if diff [1] >= 0:
|
||||
direction2 = 1
|
||||
else:
|
||||
direction2 = 7
|
||||
else:
|
||||
if diff [1] >= 0:
|
||||
direction2 = 3
|
||||
else:
|
||||
direction2 = 5
|
||||
length = np.linalg.norm(diff[:2])
|
||||
if length == 0: return 0, 0, 1, 0
|
||||
weight1 = np.abs(np.abs(diff[0]) - np.abs(diff[1])) / length
|
||||
weight2 = np.sqrt(2) * min(np.abs(diff[0]), np.abs(diff[1])) / length
|
||||
if diff[3] == 0:
|
||||
weight1 /= 2
|
||||
weight2 /= 2
|
||||
return direction1, weight1, direction2, weight2
|
||||
|
||||
|
||||
def generate_images(x_norm, x_diff_encoded, directional_feature,
|
||||
temporal_feature, scale, stroke_width):
|
||||
"""Generates image array from strokes.
|
||||
|
||||
Args:
|
||||
x_norm: [(x', y', t), ...] Normalized points.
|
||||
x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
|
||||
directional_feature (boolean): True when using direcitonal feature.
|
||||
temporal_feature (boolean): True when using temporal feature.
|
||||
scale (int): Scale of the image.
|
||||
stroke_width (int): Brush thickness to draw.
|
||||
|
||||
Returns:
|
||||
images (numpy.array): An array of images. Each image should have a shape of
|
||||
(scale, scale). Eight images will be added into the returned array if
|
||||
|directional_feature| is True, otherwise one original image will be
|
||||
added. Also, two images will be generated if |temporal_feature| is True.
|
||||
For example, the shape of |images| will be (scale, scale, 10) when both of
|
||||
options are True.
|
||||
"""
|
||||
|
||||
if directional_feature:
|
||||
images = generate_image_direct_decomp(
|
||||
x_norm, x_diff_encoded, scale, stroke_width)
|
||||
else:
|
||||
images = generate_image_plain(x_norm, x_diff_encoded, scale, stroke_width)
|
||||
if temporal_feature:
|
||||
image = generate_image_temporal(
|
||||
x_norm, x_diff_encoded, scale, stroke_width, inversed=False)
|
||||
images = np.concatenate((images, image), axis=-1)
|
||||
image = generate_image_temporal(
|
||||
x_norm, x_diff_encoded, scale, stroke_width, inversed=True)
|
||||
images = np.concatenate((images, image), axis=-1)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def generate_image_direct_decomp(x_norm, x_diff_encoded, scale, stroke_width):
|
||||
"""Generates image array from strokes using direction feature.
|
||||
|
||||
Args:
|
||||
x_norm: [(x', y', t), ...] Normalized points.
|
||||
x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
|
||||
scale (int): scale of the image.
|
||||
stroke_width (int): Brush thickness to draw.
|
||||
|
||||
Returns:
|
||||
image (numpy.array): Image array with a shape of (scale, scale, 8).
|
||||
"""
|
||||
|
||||
surfaces = [cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
|
||||
for _ in range(8)]
|
||||
|
||||
curr_x = x_norm[0][0]
|
||||
curr_y = x_norm[0][1]
|
||||
|
||||
for i, diff in enumerate(x_diff_encoded):
|
||||
direction1, weight1, direction2, weight2 = get_direction(diff)
|
||||
|
||||
ctx = cairo.Context(surfaces[direction1])
|
||||
ctx.move_to(curr_x * scale, curr_y * scale)
|
||||
ctx.set_line_width(stroke_width)
|
||||
ctx.set_source_rgba(1, 1, 1, weight1)
|
||||
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
|
||||
ctx.stroke()
|
||||
|
||||
ctx = cairo.Context(surfaces[direction2])
|
||||
ctx.move_to(curr_x * scale, curr_y * scale)
|
||||
ctx.set_line_width(stroke_width)
|
||||
ctx.set_source_rgba(1, 1, 1, weight2)
|
||||
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
|
||||
ctx.stroke()
|
||||
|
||||
curr_x += diff[0]
|
||||
curr_y += diff[1]
|
||||
|
||||
return np.array([
|
||||
surface_to_array(surface) for surface in surfaces]).transpose(1, 2, 0)
|
||||
|
||||
|
||||
def generate_image_plain(x_norm, x_diff_encoded, scale, stroke_width):
|
||||
"""Generates image array from strokes without direction feature.
|
||||
|
||||
Args:
|
||||
x_norm: [(x', y', t), ...] Normalized points.
|
||||
x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
|
||||
scale (int): scale of the image.
|
||||
stroke_width (int): Brush thickness to draw.
|
||||
|
||||
Returns:
|
||||
image (numpy.array): Image array with a shape of (scale, scale, 1).
|
||||
"""
|
||||
|
||||
surface = cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
|
||||
|
||||
curr_x = x_norm[0][0]
|
||||
curr_y = x_norm[0][1]
|
||||
|
||||
for i, diff in enumerate(x_diff_encoded):
|
||||
ctx = cairo.Context(surface)
|
||||
ctx.move_to(curr_x * scale, curr_y * scale)
|
||||
ctx.set_line_width(stroke_width)
|
||||
if diff[3] == 1:
|
||||
ctx.set_source_rgba(1, 1, 1, 1)
|
||||
else:
|
||||
ctx.set_source_rgba(1, 1, 1, 0.5)
|
||||
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
|
||||
ctx.stroke()
|
||||
|
||||
curr_x += diff[0]
|
||||
curr_y += diff[1]
|
||||
|
||||
return surface_to_array(surface).reshape(scale, scale, 1)
|
||||
|
||||
|
||||
def generate_image_temporal(x_norm, x_diff_encoded, scale, stroke_width,
|
||||
steepness=2, inversed=False):
|
||||
surface = cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
|
||||
|
||||
curr_x = x_norm[0][0]
|
||||
curr_y = x_norm[0][1]
|
||||
spent_t = 0
|
||||
|
||||
for i, diff in enumerate(x_diff_encoded):
|
||||
ctx = cairo.Context(surface)
|
||||
ctx.move_to(curr_x * scale, curr_y * scale)
|
||||
ctx.set_line_width(stroke_width)
|
||||
weight = 1 - spent_t / x_norm[-1][2]
|
||||
if inversed: weight = 1 - weight
|
||||
weight = max(weight, 0) ** steepness
|
||||
if diff[3] == 0: weight /= 2
|
||||
ctx.set_source_rgba(1, 1, 1, weight)
|
||||
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
|
||||
ctx.stroke()
|
||||
|
||||
curr_x += diff[0]
|
||||
curr_y += diff[1]
|
||||
spent_t += diff[2]
|
||||
return surface_to_array(surface).reshape(scale, scale, 1)
|
||||
|
||||
|
||||
def split_data(x, t, val_rate, test_rate):
|
||||
"""Splits data into training, validation, and testing data.
|
||||
|
||||
Args:
|
||||
x: Data to split.
|
||||
t: Label to split.
|
||||
val_rate: What percentage of data to use as a validation set.
|
||||
test_rate: What percentage of data to use as a testing set.
|
||||
|
||||
Returns:
|
||||
train_x: Training inputs.
|
||||
train_t: Training labels.
|
||||
val_x: Validation inputs.
|
||||
val_t: Validation labels.
|
||||
test_x: Testing inputs.
|
||||
test_t: Testing labels.
|
||||
"""
|
||||
|
||||
n = x.shape[0]
|
||||
train_x = x[:int(n * (1 - val_rate - test_rate))]
|
||||
train_t = t[:int(n * (1 - val_rate - test_rate))]
|
||||
val_x= x[int(n * (1 - val_rate - test_rate)):int(n * (1 - test_rate))]
|
||||
val_t = t[int(n * (1 - val_rate - test_rate)):int(n * (1 - test_rate))]
|
||||
test_x = x[int(n * (1 - test_rate)):]
|
||||
test_t = t[int(n * (1 - test_rate)):]
|
||||
return train_x, train_t, val_x, val_t, test_x, test_t
|
||||
|
||||
|
||||
def keydowns2image(keydowns, directional_feature, temporal_feature, scale=16,
|
||||
stroke_width=2):
|
||||
"""Converts a list of keydowns into image.
|
||||
|
||||
Args:
|
||||
keydowns: [(key, t), ...] Training data as a list of keydowns.
|
||||
directional_feature (boolean): True when using directional feature.
|
||||
temporal_feature (boolean): True when using temporal feature.
|
||||
scale (int): Scale of the image.
|
||||
stroke_width (int): Brush thickness to draw.
|
||||
|
||||
Returns:
|
||||
X_im: Image dataset in numpy array format. The shape differs by used
|
||||
features.
|
||||
(directional=True, temporal=True) => (scale, scale, 10)
|
||||
(directional=True, temporal=False) => (scale, scale, 8)
|
||||
(directional=False, temporal=True) => (scale, scale, 3)
|
||||
(directional=False, temporal=False) => (scale, scale, 1)
|
||||
"""
|
||||
|
||||
# Translate keys to 2D points. {(key, t), ...} -> {(x, y, t), ...}
|
||||
X = keydowns2points(keydowns)
|
||||
|
||||
# 0-1 normalization
|
||||
X_norm = normalize_x(X)
|
||||
|
||||
# Take difference. {(x, y, t), ...} -> {(dx, dy, dt), ...}.
|
||||
X_diff = np.diff(X_norm, axis=0)
|
||||
|
||||
# Encode pendown state. {(dx, dy, dt), ...} -> {(dx, dy, dt, pendown), ...}
|
||||
X_diff_encoded = pendown_encode(X_diff)
|
||||
|
||||
# Render into images.
|
||||
X_im = generate_images(X_norm, X_diff_encoded, directional_feature,
|
||||
temporal_feature, scale, stroke_width) / 255.
|
||||
|
||||
return X_im
|
||||
30
mozc-nazoru/src/nazoru/nazoru_test.py
Normal file
30
mozc-nazoru/src/nazoru/nazoru_test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import nazoru
|
||||
|
||||
class TestNazoru(unittest.TestCase):
|
||||
"""Test class for nazoru.py"""
|
||||
|
||||
def test_key2pos(self):
|
||||
pos = nazoru.key2pos('a', ['abc', 'def'], 0.5)
|
||||
self.assertEqual((0,0), pos)
|
||||
pos = nazoru.key2pos('e', ['abc', 'def'], 0.5)
|
||||
self.assertEqual((1.5,1), pos)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
259
mozc-nazoru/src/nazoru/nazorunet.py
Normal file
259
mozc-nazoru/src/nazoru/nazorunet.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""NazoruNet.
|
||||
|
||||
NazoruNet is a customized version of MobileNet architecture and can be used to
|
||||
predict input characters from visualized trace.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
from tensorflow.contrib import slim
|
||||
import tensorflow as tf
|
||||
|
||||
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
|
||||
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
|
||||
|
||||
|
||||
def mobilenet_v1_base(inputs,
|
||||
final_endpoint,
|
||||
conv_defs,
|
||||
min_depth=8,
|
||||
depth_multiplier=1.0,
|
||||
bn_decay=0.95,
|
||||
output_stride=None,
|
||||
scope=None):
|
||||
"""Mobilenet v1.
|
||||
Constructs a Mobilenet v1 network from inputs to the given final endpoint.
|
||||
Customized to accept batch normalization decay parameter to accelerate
|
||||
learning speed.
|
||||
Args:
|
||||
inputs: a tensor of shape [batch_size, height, width, channels].
|
||||
final_endpoint: specifies the endpoint to construct the network up to. It
|
||||
can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
|
||||
'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5'_pointwise,
|
||||
'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
|
||||
'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
|
||||
'Conv2d_12_pointwise', 'Conv2d_13_pointwise'].
|
||||
conv_defs: A list of ConvDef namedtuples specifying the net architecture.
|
||||
min_depth: Minimum depth value (number of channels) for all convolution ops.
|
||||
Enforced when depth_multiplier < 1, and not an active constraint when
|
||||
depth_multiplier >= 1.
|
||||
depth_multiplier: Float multiplier for the depth (number of channels)
|
||||
for all convolution ops. The value must be greater than zero. Typical
|
||||
usage will be to set this value in (0, 1) to reduce the number of
|
||||
parameters or computation cost of the model.
|
||||
bn_decay: Decay parameter for batch normalization layer.
|
||||
output_stride: An integer that specifies the requested ratio of input to
|
||||
output spatial resolution. If not None, then we invoke atrous convolution
|
||||
if necessary to prevent the network from reducing the spatial resolution
|
||||
of the activation maps. Allowed values are 8 (accurate fully convolutional
|
||||
mode), 16 (fast fully convolutional mode), 32 (classification mode).
|
||||
scope: Optional variable_scope.
|
||||
Returns:
|
||||
tensor_out: output tensor corresponding to the final_endpoint.
|
||||
end_points: a set of activations for external use, for example summaries or
|
||||
losses.
|
||||
Raises:
|
||||
ValueError: if final_endpoint is not set to one of the predefined values,
|
||||
or depth_multiplier <= 0, or the target output_stride is not
|
||||
allowed.
|
||||
"""
|
||||
depth = lambda d: max(int(d * depth_multiplier), min_depth)
|
||||
end_points = {}
|
||||
|
||||
# Used to find thinned depths for each layer.
|
||||
if depth_multiplier <= 0:
|
||||
raise ValueError('depth_multiplier is not greater than zero.')
|
||||
|
||||
if output_stride is not None and output_stride not in [8, 16, 32]:
|
||||
raise ValueError('Only allowed output_stride values are 8, 16, 32.')
|
||||
|
||||
with tf.variable_scope(scope, 'MobilenetV1', [inputs]):
|
||||
with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'):
|
||||
# The current_stride variable keeps track of the output stride of the
|
||||
# activations, i.e., the running product of convolution strides up to the
|
||||
# current network layer. This allows us to invoke atrous convolution
|
||||
# whenever applying the next convolution would result in the activations
|
||||
# having output stride larger than the target output_stride.
|
||||
current_stride = 1
|
||||
|
||||
# The atrous convolution rate parameter.
|
||||
rate = 1
|
||||
|
||||
net = inputs
|
||||
for i, conv_def in enumerate(conv_defs):
|
||||
end_point_base = 'Conv2d_%d' % i
|
||||
|
||||
if output_stride is not None and current_stride == output_stride:
|
||||
# If we have reached the target output_stride, then we need to employ
|
||||
# atrous convolution with stride=1 and multiply the atrous rate by the
|
||||
# current unit's stride for use in subsequent layers.
|
||||
layer_stride = 1
|
||||
layer_rate = rate
|
||||
rate *= conv_def.stride
|
||||
else:
|
||||
layer_stride = conv_def.stride
|
||||
layer_rate = 1
|
||||
current_stride *= conv_def.stride
|
||||
|
||||
if isinstance(conv_def, Conv):
|
||||
end_point = end_point_base
|
||||
net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
|
||||
stride=conv_def.stride,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
normalizer_params={'decay': bn_decay},
|
||||
scope=end_point)
|
||||
end_points[end_point] = net
|
||||
if end_point == final_endpoint:
|
||||
return net, end_points
|
||||
|
||||
elif isinstance(conv_def, DepthSepConv):
|
||||
end_point = end_point_base + '_depthwise'
|
||||
|
||||
# By passing filters=None
|
||||
# separable_conv2d produces only a depthwise convolution layer
|
||||
net = slim.separable_conv2d(net, None, conv_def.kernel,
|
||||
depth_multiplier=1,
|
||||
stride=layer_stride,
|
||||
rate=layer_rate,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
normalizer_params={'decay': bn_decay},
|
||||
scope=end_point)
|
||||
|
||||
end_points[end_point] = net
|
||||
if end_point == final_endpoint:
|
||||
return net, end_points
|
||||
|
||||
end_point = end_point_base + '_pointwise'
|
||||
|
||||
net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
|
||||
stride=1,
|
||||
normalizer_fn=slim.batch_norm,
|
||||
normalizer_params={'decay': bn_decay},
|
||||
scope=end_point)
|
||||
|
||||
end_points[end_point] = net
|
||||
if end_point == final_endpoint:
|
||||
return net, end_points
|
||||
else:
|
||||
raise ValueError('Unknown convolution type %s for layer %d'
|
||||
% (conv_def.ltype, i))
|
||||
raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
||||
|
||||
|
||||
def nazorunet(inputs,
|
||||
num_classes,
|
||||
conv_defs,
|
||||
dropout_keep_prob=0.999,
|
||||
is_training=True,
|
||||
min_depth=8,
|
||||
depth_multiplier=1.0,
|
||||
prediction_fn=tf.contrib.layers.softmax,
|
||||
spatial_squeeze=True,
|
||||
reuse=None,
|
||||
scope='NazoruNet',
|
||||
global_pool=False):
|
||||
"""Customized MobileNet model for Nazoru Input.
|
||||
Args:
|
||||
inputs: a tensor of shape [batch_size, height, width, channels].
|
||||
num_classes: number of predicted classes. If 0 or None, the logits layer
|
||||
is omitted and the input features to the logits layer (before dropout)
|
||||
are returned instead.
|
||||
dropout_keep_prob: the percentage of activation values that are retained.
|
||||
is_training: whether is training or not.
|
||||
min_depth: Minimum depth value (number of channels) for all convolution ops.
|
||||
Enforced when depth_multiplier < 1, and not an active constraint when
|
||||
depth_multiplier >= 1.
|
||||
depth_multiplier: Float multiplier for the depth (number of channels)
|
||||
for all convolution ops. The value must be greater than zero. Typical
|
||||
usage will be to set this value in (0, 1) to reduce the number of
|
||||
parameters or computation cost of the model.
|
||||
conv_defs: A list of ConvDef namedtuples specifying the net architecture.
|
||||
prediction_fn: a function to get predictions out of logits.
|
||||
spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
|
||||
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
|
||||
reuse: whether or not the network and its variables should be reused. To be
|
||||
able to reuse 'scope' must be given.
|
||||
scope: Optional variable_scope.
|
||||
global_pool: Optional boolean flag to control the avgpooling before the
|
||||
logits layer. If false or unset, pooling is done with a fixed window
|
||||
that reduces default-sized inputs to 1x1, while larger inputs lead to
|
||||
larger outputs. If true, any input size is pooled down to 1x1.
|
||||
Returns:
|
||||
net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
|
||||
is a non-zero integer, or the non-dropped-out input to the logits layer
|
||||
if num_classes is 0 or None.
|
||||
end_points: a dictionary from components of the network to the corresponding
|
||||
activation.
|
||||
Raises:
|
||||
ValueError: Input rank is invalid.
|
||||
"""
|
||||
input_shape = inputs.get_shape().as_list()
|
||||
if len(input_shape) != 4:
|
||||
raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
|
||||
len(input_shape))
|
||||
|
||||
with tf.variable_scope(scope, 'NazoruNet', [inputs], reuse=reuse) as scope:
|
||||
with slim.arg_scope([slim.batch_norm, slim.dropout],
|
||||
is_training=is_training):
|
||||
final_endpoint = 'Conv2d_%i_pointwise' % (len(conv_defs) - 1)
|
||||
net, end_points = mobilenet_v1_base(inputs, scope=scope,
|
||||
min_depth=min_depth,
|
||||
depth_multiplier=depth_multiplier,
|
||||
conv_defs=conv_defs,
|
||||
final_endpoint=final_endpoint)
|
||||
with tf.variable_scope('Logits'):
|
||||
if global_pool:
|
||||
# Global average pooling.
|
||||
net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
|
||||
end_points['global_pool'] = net
|
||||
else:
|
||||
# Pooling with a fixed kernel size.
|
||||
kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
|
||||
net = slim.avg_pool2d(net, kernel_size, padding='VALID',
|
||||
scope='AvgPool_1a')
|
||||
end_points['AvgPool_1a'] = net
|
||||
if not num_classes:
|
||||
return net, end_points
|
||||
# 1 x 1 x 1024
|
||||
net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
|
||||
logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
|
||||
normalizer_fn=None, scope='Conv2d_1c_1x1')
|
||||
if spatial_squeeze:
|
||||
logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
|
||||
end_points['Logits'] = logits
|
||||
if prediction_fn:
|
||||
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
|
||||
return logits, end_points
|
||||
|
||||
|
||||
def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
|
||||
"""Define kernel size which is automatically reduced for small input.
|
||||
If the shape of the input images is unknown at graph construction time this
|
||||
function assumes that the input images are large enough.
|
||||
Args:
|
||||
input_tensor: input tensor of size [batch_size, height, width, channels].
|
||||
kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
|
||||
Returns:
|
||||
a tensor with the kernel size.
|
||||
"""
|
||||
shape = input_tensor.get_shape().as_list()
|
||||
if shape[1] is None or shape[2] is None:
|
||||
kernel_size_out = kernel_size
|
||||
else:
|
||||
kernel_size_out = [min(shape[1], kernel_size[0]),
|
||||
min(shape[2], kernel_size[1])]
|
||||
return kernel_size_out
|
||||
64
mozc-nazoru/src/nazoru/predictor.py
Normal file
64
mozc-nazoru/src/nazoru/predictor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import tensorflow as tf
|
||||
from . import lib
|
||||
from . import utils
|
||||
import numpy as np
|
||||
|
||||
def _load_graph(model_file):
|
||||
graph = tf.Graph()
|
||||
graph_def = tf.GraphDef()
|
||||
with open(model_file, "rb") as f:
|
||||
graph_def.ParseFromString(f.read())
|
||||
with graph.as_default():
|
||||
tf.import_graph_def(graph_def)
|
||||
return graph
|
||||
|
||||
class NazoruPredictor():
|
||||
def __init__(self, model_file):
|
||||
graph = _load_graph(model_file)
|
||||
self._graph = graph
|
||||
self._input_operation = graph.get_operation_by_name(
|
||||
'import/' + lib.INPUT_NODE_NAME)
|
||||
self._output_operation = graph.get_operation_by_name(
|
||||
'import/' + lib.OUTPUT_NODE_NAME)
|
||||
|
||||
def _predict(self, data):
|
||||
with utils.Measure('inputs'):
|
||||
inputs = lib.keydowns2image(data, True, True, 16, 2)
|
||||
inputs = np.expand_dims(inputs, axis=0)
|
||||
with utils.Measure('sess.run'):
|
||||
with tf.Session(graph=self._graph) as sess:
|
||||
result = sess.run(self._output_operation.outputs[0],
|
||||
{self._input_operation.outputs[0]: inputs})[0]
|
||||
return result
|
||||
def predict_top_n(self, data, n):
|
||||
"""Predict the charactor drawn by |data|.
|
||||
|
||||
Args:
|
||||
data: [(key, time)] |time| is elapsed time since the first character in ms.
|
||||
n: integer of the number of the return value.
|
||||
Returns:
|
||||
ans: [(kana, key, probability)] sorted by the probability.
|
||||
"""
|
||||
result = self._predict(data)
|
||||
ans = []
|
||||
for i in result.argsort()[::-1][:n]:
|
||||
ans.append((lib.KANAS[i], lib.KEYS[i], result[i]))
|
||||
return ans
|
||||
28
mozc-nazoru/src/nazoru/utils.py
Normal file
28
mozc-nazoru/src/nazoru/utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
class Measure():
|
||||
def __init__(self, tag):
|
||||
self._tag = tag
|
||||
|
||||
def __enter__(self):
|
||||
self._start = time.time()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
now = time.time()
|
||||
print('[{0}] {1} ms'.format(self._tag, (now - self._start)*1E3))
|
||||
Reference in New Issue
Block a user