Add nazoru-input in mozc-nazoru/

Makoto Shimazu
2018-03-16 13:27:36 +09:00
parent 54d39f5615
commit e006ebd113
21 changed files with 5247 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
def get_default_graph_path():
import os
script_dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(script_dir, 'data', 'optimized_nazoru.pb')
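A usage sketch (illustrative, not part of this commit; the import path is an assumption, since the diff view does not show file names):

# Hypothetical import path for the helper above.
from nazoru.core import get_default_graph_path
print(get_default_graph_path())  # -> .../nazoru/data/optimized_nazoru.pb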

View File

@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import serial
import struct
class Bluetooth():
DEVICE_FILE = '/dev/serial0'
BAUDRATE = 115200
def __init__(self):
try:
self._conn = serial.Serial(self.DEVICE_FILE, self.BAUDRATE)
self._dummy = False
print('Bluetooth')
except serial.SerialException:
self._conn = None
self._dummy = True
print('Dummy Bluetooth')
  def send(self, string):
    """Sends |string| as a series of characters. |string| should consist of
    alphanumeric characters and symbols that can be typed from a keyboard."""
    if self._dummy:
      print('bluetooth: {}'.format(string))
      return
    if isinstance(string, str):
      # pyserial expects bytes on Python 3; keyboard characters are ASCII-safe.
      string = string.encode('ascii')
    self._conn.write(string)
# See http://ww1.microchip.com/downloads/en/DeviceDoc/bluetooth_cr_UG-v1.0r.pdf
  # for details.
UART_CODES = {
'KEY_RIGHT': 7,
'KEY_BACKSPACE': 8,
'KEY_ENTER': 10,
'KEY_LEFT': 11,
'KEY_DOWN': 12,
'KEY_UP': 14,
}
def command(self, cmd):
if cmd not in self.UART_CODES:
print('Unknown Command: {}'.format(cmd))
return
self.send(struct.pack('b', self.UART_CODES[cmd]))
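For reference, a minimal driving sketch (illustrative; assumes a paired RN42-style HID module on /dev/serial0):

bt = Bluetooth()         # Falls back to the dummy when /dev/serial0 is absent.
bt.send('hello')         # Printable characters are typed as-is.
bt.command('KEY_ENTER')  # Control keys go through the UART code table.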

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import lib
from .bluetooth import Bluetooth
from .keyboard_recorder import create_keyboard_recorder
from .nazorunet import nazorunet
from .nazorunet import Conv
from .nazorunet import DepthSepConv
from .predictor import NazoruPredictor

Binary file not shown.

View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from enum import Enum
import datetime
import fcntl
import os
import struct
import sys
import termios
try:
import evdev
except ImportError:
evdev = None
def _set_raw_mode_stdin():
fno = sys.stdin.fileno()
attr_old = termios.tcgetattr(fno)
fcntl_old = fcntl.fcntl(fno, fcntl.F_GETFL)
attr_new = termios.tcgetattr(fno)
attr_new[3] = attr_new[3] & ~termios.ECHO & ~termios.ICANON
termios.tcsetattr(fno, termios.TCSADRAIN, attr_new)
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old | os.O_NONBLOCK)
def reset_raw_mode():
termios.tcsetattr(fno, termios.TCSANOW, attr_old)
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old)
return reset_raw_mode
def _set_raw_mode_general(f):
fno = f.fileno()
fcntl_old = fcntl.fcntl(fno, fcntl.F_GETFL)
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old | os.O_NONBLOCK)
def reset_raw_mode():
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old)
return reset_raw_mode
class KeyboardRecorder():
def __init__(self, verbose=False):
self._verbose = verbose
def record(self):
return (None, None)
def log(self, *args, **kwargs):
# TODO(shimazu): Use logger.
if self._verbose:
print(*args, file=sys.stderr, **kwargs)
class KeyboardRecorderFromConsole(KeyboardRecorder):
def __init__(self, verbose=False):
KeyboardRecorder.__init__(self, verbose)
self.log('Input from console')
def record(self):
"""
Returns a tuple of |data| and |command|.
    |data|: an array of (key, elapsed_ms) tuples, where elapsed_ms is the time
      since the first keystroke in milliseconds.
|command|: None
"""
recording = False
start_time = None
last_time = None
wait_seconds = 2
data = []
reset_raw_mode = _set_raw_mode_stdin()
try:
while 1:
try:
key = sys.stdin.read(1)
except IOError:
key = None
finally:
now = datetime.datetime.now()
if key == '\n':
return (None, None)
elif key:
if not recording:
recording = True
            start_time = now
elapsed_time = now - start_time
elapsed_ms = int(elapsed_time.total_seconds() * 1000)
last_time = now
data.append((key, elapsed_ms))
self.log(key, elapsed_ms)
if last_time and (now - last_time).total_seconds() > wait_seconds:
break
finally:
reset_raw_mode()
return (data, None)
class KeyboardRecorderFromEvdev(KeyboardRecorder):
KEYS = {
2: '1',
3: '2',
4: '3',
5: '4',
6: '5',
7: '6',
8: '7',
9: '8',
10: '9',
11: '0',
12: '-',
13: '=',
16: 'q',
17: 'w',
18: 'e',
19: 'r',
20: 't',
21: 'y',
22: 'u',
23: 'i',
24: 'o',
25: 'p',
26: '[',
27: ']',
30: 'a',
31: 's',
32: 'd',
33: 'f',
34: 'g',
35: 'h',
36: 'j',
37: 'k',
38: 'l',
39: ';',
40: '\'',
43: '\\',
44: 'z',
45: 'x',
46: 'c',
47: 'v',
48: 'b',
49: 'n',
50: 'm',
51: ',',
52: '.',
53: '/'
}
WAIT_SECONDS = 2
def __init__(self, verbose=False):
if evdev is None:
      raise TypeError('KeyboardRecorderFromEvdev needs to be used on a Linux '
                      '(or POSIX-compatible) system. Alternatively, you can '
                      'try it on your console.')
KeyboardRecorder.__init__(self, verbose)
self.log('Input from evdev')
keyboards = []
ecode_ev_key = evdev.ecodes.ecodes['EV_KEY']
ecode_key_esc = evdev.ecodes.ecodes['KEY_ESC']
for device in [evdev.InputDevice(fn) for fn in evdev.list_devices()]:
# TODO(shimazu): Consider more solid criteria to get 'regular' keyboards.
if ecode_ev_key in device.capabilities() and \
ecode_key_esc in device.capabilities()[ecode_ev_key]:
keyboards.append(device)
if len(keyboards) == 0:
raise IOError('No keyboard found.')
self._keyboards = keyboards
for keyboard in keyboards:
self.log('----')
self.log(keyboard)
self.log('name: {0}'.format(keyboard.name))
self.log('phys: {0}'.format(keyboard.phys))
self.log('repeat: {0}'.format(keyboard.repeat))
self.log('info: {0}'.format(keyboard.info))
self.log(keyboard.capabilities(verbose=True))
def record(self):
"""
Returns a tuple of |data| and |command|.
    |data|: an array of (key, elapsed_ms) tuples, where elapsed_ms is the time
      since the first keystroke in milliseconds. None if the input is a
      non-alphanumeric, non-symbol key such as ENTER or an arrow key.
    |command|: A command name like 'KEY_ENTER', or None when |data| is valid.
"""
start_time = None
last_time = None
data = []
while True:
# TODO(shimazu): Check inputs from all keyboards.
event = self._keyboards[0].read_one()
now = datetime.datetime.now()
if last_time and (now - last_time).total_seconds() > self.WAIT_SECONDS:
break
if event is None:
continue
name = evdev.ecodes.bytype[event.type][event.code]
ev_type = evdev.ecodes.EV[event.type]
if ev_type != 'EV_KEY':
continue
# Keyboard input
self.log('----')
self.log(event)
self.log('name: {}'.format(name))
self.log('type: {}'.format(ev_type))
# Check if the event is from releasing the button
if event.value == 0:
continue
# It may be a non-alphabet/numeric/symbol key. Return it as a command.
if event.code not in self.KEYS:
if start_time is not None:
continue
return (None, name)
last_time = now
if start_time is None:
start_time = now
elapsed_ms = int((now - start_time).total_seconds() * 1000)
data.append((self.KEYS[event.code], elapsed_ms))
return (data, None)
InputSource = Enum('InputSource', 'EVDEV CONSOLE')
def create_keyboard_recorder(verbose=False, source=None):
"""Creates KeyboardRecorder.
Args:
    verbose: Prints details of the input when True.
source: InputSource.EVDEV, InputSource.CONSOLE or None (default)
Returns:
recorder: Corresponding KeyboardRecorder. If |source| is None, returns
KeyboardRecorderFromConsole when stdin is attached to a console (isatty is
true). Otherwise, returns KeyboardRecorderFromEvdev.
"""
if source == InputSource.CONSOLE:
return KeyboardRecorderFromConsole(verbose=verbose)
if source == InputSource.EVDEV:
return KeyboardRecorderFromEvdev(verbose=verbose)
if sys.__stdin__.isatty():
return KeyboardRecorderFromConsole(verbose=verbose)
return KeyboardRecorderFromEvdev(verbose=verbose)
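A sketch of the intended call pattern (illustrative):

recorder = create_keyboard_recorder(verbose=True)
data, command = recorder.record()
if command is not None:
  print('command: {}'.format(command))  # e.g. 'KEY_ENTER'
elif data is not None:
  print('keydowns: {}'.format(data))    # e.g. [('k', 0), ('j', 118), ...]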

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LEDBase():
def __init__(self, pin):
self._pin = pin
def on(self):
pass
def off(self):
pass
def blink(self, interval):
pass
try:
import RPi.GPIO as GPIO
import threading
class LED(LEDBase):
ON = 'ON'
OFF = 'OFF'
def __init__(self, pin):
GPIO.setmode(GPIO.BOARD)
self._pin = pin
self._lock = threading.Lock()
self._timer = None
GPIO.setup(pin, GPIO.OUT)
self.off()
def on(self):
with self._lock:
self._state = self.ON
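      # Note: the pin is driven low to turn the LED on (active-low wiring).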
GPIO.output(self._pin, False)
self._ensure_stop_timer()
def off(self):
with self._lock:
self._state = self.OFF
GPIO.output(self._pin, True)
self._ensure_stop_timer()
def blink(self, interval):
self._ensure_stop_timer()
def toggle():
self._timer = None
if self._state == self.ON:
self.off()
else:
self.on()
self._timer = threading.Timer(interval, toggle)
self._timer.daemon = True
self._timer.start()
toggle()
def _ensure_stop_timer(self):
if self._timer is not None:
self._timer.cancel()
self._timer = None
except ImportError:
  # RPi.GPIO is unavailable (e.g. not running on a Raspberry Pi); fall back
  # to the no-op LED implementation.
  class LED(LEDBase):
    pass
LED_BLUE = LED(38)
LED_RED = LED(40)

View File

@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nazoru input library.
This is a collection of methods to preprocess input stroke data before any
training starts.
"""
import time
import random
import cairocffi as cairo
import numpy as np
from PIL import Image
from io import BytesIO
from enum import Enum
SCOPE = 'Nazorunet'
INPUT_NODE_NAME = 'inputs'
OUTPUT_NODE_NAME = SCOPE + '/Predictions/Reshape_1'
KANAS = (u'あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほ'
u'まみむめもやゆよらりるれろわゐんゑを'
u'abcdefghijklmnopqrstuvwxyz1234567890'
u'♡ーずぐ')
KEYS = ('a', 'i', 'u', 'e', 'o',
'ka', 'ki', 'ku', 'ke', 'ko',
'sa', 'si', 'su', 'se', 'so',
'ta', 'ti', 'tu', 'te', 'to',
'na', 'ni', 'nu', 'ne', 'no',
'ha', 'hi', 'hu', 'he', 'ho',
'ma', 'mi', 'mu', 'me', 'mo',
'ya', 'yu', 'yo',
'ra', 'ri', 'ru', 're', 'ro',
'wa', 'wi', 'nn', 'we', 'wo',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'1', '2', '3', '4', '5', '6', '7', '8', '9', '0',
'ha-to', '-', 'zu', 'gu')
class KeyboardArrangement(Enum):
"""Enum for keyboard arrangements.
"""
qwerty_jis = [
u'1234567890-^¥',
u'qwertyuiop@[',
u'asdfghjkl;:]',
u'zxcvbnm,./_',
]
def key2pos(key, arrangement=KeyboardArrangement.qwerty_jis.value, offset=0.5):
"""Returns the key position.
Args:
key (string): Key to get position.
arrangement (list): Keyboard arrangement.
    offset (number): Horizontal shift applied to each successive row of keys.
Returns:
position (tuple(number, number)): Position (x, y).
"""
for i, row in enumerate(arrangement):
if key in row:
y = i
x = row.index(key) + i * offset
return (x, y)
return None
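# Worked examples (illustrative), following the row offset above:
#   key2pos('1') == (0.0, 0)  # row 0, index 0, no shift
#   key2pos('q') == (0.5, 1)  # row 1, index 0, shifted by 1 * 0.5
#   key2pos('z') == (1.5, 3)  # row 3, index 0, shifted by 3 * 0.5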
def keydowns2points(keydowns):
"""Translates keydowns to points.
Args:
keydowns: [(key, t), ...] List of keydowns.
Returns:
points: [(x, y, t), ...] List of points.
"""
points = []
for keydown in keydowns:
pos = key2pos(keydown[0])
if pos:
points.append((pos[0], pos[1], keydown[1]))
return points
def normalize_x(x):
"""Normalizes position.
Args:
x (list): [[x, y, t], ...] List of points to normalize the position (x, y)
into 0-1 range.
Returns:
    x (list): [[x', y', t], ...] List of points with the normalized position
(x', y').
"""
x = np.array(x)
max_ = np.max(x[:, :2], axis=0)
min_ = np.min(x[:, :2], axis=0)
x[:, :2] = (x[:, :2] - min_) / (max_ - min_)
return x
def pendown_encode(x_diff, sigma=1.6):
"""Encodes time difference into pendown state.
Args:
    x_diff (list): [[dx, dy, dt], ...] List of diffs to encode.
    sigma (number): Number of standard deviations above the mean dt beyond
      which a gap is treated as a pen-up.
Returns:
x_diff_encoded (list): [[dx, dy, dt, pendown], ...] Encoded list of diffs.
"""
  thres = np.mean(x_diff[:, 2]) + sigma * np.std(x_diff[:, 2])
x_diff_encoded = np.concatenate((
x_diff,
[[0] if dt_i > thres else [1] for dt_i in x_diff[:, 2]]
), axis=1)
return x_diff_encoded
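# Illustrative run with hand-picked numbers: for dt values [100, 120, 2000,
# 110], the threshold is mean + 1.6 * std ~= 1892 ms, so only the dt=2000 gap
# is flagged as pen-up:
#   pendown_encode(np.array([[.1, 0, 100], [0, .2, 120],
#                            [.3, .1, 2000], [.1, .1, 110]]))
#   -> dx, dy, dt columns unchanged; pendown column is [1, 1, 0, 1]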
def surface_to_array(surface):
"""Returns image array from cairo surface.
Args:
surface: Cairo surface to translate.
"""
buf = BytesIO()
surface.write_to_png(buf)
png_string = buf.getvalue()
im = Image.open(BytesIO(png_string))
imdata = np.asarray(im.convert('L'))
return imdata
def get_direction(diff):
"""Returns directions and weights for 8-directional features.
For more detail, see
- Bai, Zhen-Long, and Qiang Huo. "A study on the use of 8-directional features
for online handwritten Chinese character recognition."
- Liu, Cheng-Lin, and Xiang-Dong Zhou. "Online Japanese character recognition
using trajectory-based normalization and direction feature extraction."
Weight is halved for pen-up states.
Args:
diff (numpy.array): Encoded diff vector (dx, dy, dt, pendown).
Returns:
    First direction (Right (0), Down (2), Left (4), Up (6)) and its weight, and
    second direction (Down-right (1), Down-left (3), Up-left (5), Up-right
    (7)) and its weight.
"""
if np.abs(diff[0]) >= np.abs(diff[1]):
if diff[0] >= 0:
direction1 = 0
else:
direction1 = 4
else:
if diff[1] >= 0:
direction1 = 2
else:
direction1 = 6
if diff[0] >= 0:
    if diff[1] >= 0:
direction2 = 1
else:
direction2 = 7
else:
    if diff[1] >= 0:
direction2 = 3
else:
direction2 = 5
length = np.linalg.norm(diff[:2])
if length == 0: return 0, 0, 1, 0
weight1 = np.abs(np.abs(diff[0]) - np.abs(diff[1])) / length
weight2 = np.sqrt(2) * min(np.abs(diff[0]), np.abs(diff[1])) / length
if diff[3] == 0:
weight1 /= 2
weight2 /= 2
return direction1, weight1, direction2, weight2
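# Illustrative decomposition of a pen-down segment moving right and slightly
# down, diff = (3, 1, dt, 1):
#   direction1 = 0 (right),      weight1 = |3 - 1| / sqrt(10) ~= 0.63
#   direction2 = 1 (down-right), weight2 = sqrt(2) * 1 / sqrt(10) ~= 0.45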
def generate_images(x_norm, x_diff_encoded, directional_feature,
temporal_feature, scale, stroke_width):
"""Generates image array from strokes.
Args:
x_norm: [(x', y', t), ...] Normalized points.
x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
    directional_feature (boolean): True when using directional feature.
temporal_feature (boolean): True when using temporal feature.
scale (int): Scale of the image.
stroke_width (int): Brush thickness to draw.
Returns:
images (numpy.array): An array of images. Each image should have a shape of
(scale, scale). Eight images will be added into the returned array if
|directional_feature| is True, otherwise one original image will be
added. Also, two images will be generated if |temporal_feature| is True.
    For example, the shape of |images| will be (scale, scale, 10) when both
    options are True.
"""
if directional_feature:
images = generate_image_direct_decomp(
x_norm, x_diff_encoded, scale, stroke_width)
else:
images = generate_image_plain(x_norm, x_diff_encoded, scale, stroke_width)
if temporal_feature:
image = generate_image_temporal(
x_norm, x_diff_encoded, scale, stroke_width, inversed=False)
images = np.concatenate((images, image), axis=-1)
image = generate_image_temporal(
x_norm, x_diff_encoded, scale, stroke_width, inversed=True)
images = np.concatenate((images, image), axis=-1)
return images
def generate_image_direct_decomp(x_norm, x_diff_encoded, scale, stroke_width):
"""Generates image array from strokes using direction feature.
Args:
x_norm: [(x', y', t), ...] Normalized points.
x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
scale (int): scale of the image.
stroke_width (int): Brush thickness to draw.
Returns:
image (numpy.array): Image array with a shape of (scale, scale, 8).
"""
surfaces = [cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
for _ in range(8)]
curr_x = x_norm[0][0]
curr_y = x_norm[0][1]
for i, diff in enumerate(x_diff_encoded):
direction1, weight1, direction2, weight2 = get_direction(diff)
ctx = cairo.Context(surfaces[direction1])
ctx.move_to(curr_x * scale, curr_y * scale)
ctx.set_line_width(stroke_width)
ctx.set_source_rgba(1, 1, 1, weight1)
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
ctx.stroke()
ctx = cairo.Context(surfaces[direction2])
ctx.move_to(curr_x * scale, curr_y * scale)
ctx.set_line_width(stroke_width)
ctx.set_source_rgba(1, 1, 1, weight2)
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
ctx.stroke()
curr_x += diff[0]
curr_y += diff[1]
return np.array([
surface_to_array(surface) for surface in surfaces]).transpose(1, 2, 0)
def generate_image_plain(x_norm, x_diff_encoded, scale, stroke_width):
"""Generates image array from strokes without direction feature.
Args:
x_norm: [(x', y', t), ...] Normalized points.
x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
scale (int): scale of the image.
stroke_width (int): Brush thickness to draw.
Returns:
image (numpy.array): Image array with a shape of (scale, scale, 1).
"""
surface = cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
curr_x = x_norm[0][0]
curr_y = x_norm[0][1]
for i, diff in enumerate(x_diff_encoded):
ctx = cairo.Context(surface)
ctx.move_to(curr_x * scale, curr_y * scale)
ctx.set_line_width(stroke_width)
if diff[3] == 1:
ctx.set_source_rgba(1, 1, 1, 1)
else:
ctx.set_source_rgba(1, 1, 1, 0.5)
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
ctx.stroke()
curr_x += diff[0]
curr_y += diff[1]
return surface_to_array(surface).reshape(scale, scale, 1)
def generate_image_temporal(x_norm, x_diff_encoded, scale, stroke_width,
steepness=2, inversed=False):
  """Generates an image whose stroke opacity encodes elapsed time.
  Strokes fade out over time (or fade in when |inversed| is True), weighted
  by max(1 - spent_t / total_t, 0) ** steepness; pen-up segments get half
  weight.
  """
  surface = cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
curr_x = x_norm[0][0]
curr_y = x_norm[0][1]
spent_t = 0
for i, diff in enumerate(x_diff_encoded):
ctx = cairo.Context(surface)
ctx.move_to(curr_x * scale, curr_y * scale)
ctx.set_line_width(stroke_width)
weight = 1 - spent_t / x_norm[-1][2]
if inversed: weight = 1 - weight
weight = max(weight, 0) ** steepness
if diff[3] == 0: weight /= 2
ctx.set_source_rgba(1, 1, 1, weight)
ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
ctx.stroke()
curr_x += diff[0]
curr_y += diff[1]
spent_t += diff[2]
return surface_to_array(surface).reshape(scale, scale, 1)
def split_data(x, t, val_rate, test_rate):
"""Splits data into training, validation, and testing data.
Args:
x: Data to split.
t: Label to split.
    val_rate: Fraction of the data to use as the validation set.
    test_rate: Fraction of the data to use as the testing set.
Returns:
train_x: Training inputs.
train_t: Training labels.
val_x: Validation inputs.
val_t: Validation labels.
test_x: Testing inputs.
test_t: Testing labels.
"""
n = x.shape[0]
train_x = x[:int(n * (1 - val_rate - test_rate))]
train_t = t[:int(n * (1 - val_rate - test_rate))]
  val_x = x[int(n * (1 - val_rate - test_rate)):int(n * (1 - test_rate))]
val_t = t[int(n * (1 - val_rate - test_rate)):int(n * (1 - test_rate))]
test_x = x[int(n * (1 - test_rate)):]
test_t = t[int(n * (1 - test_rate)):]
return train_x, train_t, val_x, val_t, test_x, test_t
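# Illustrative split: with n = 100 samples, val_rate=0.1 and test_rate=0.1
# yield an 80/10/10 split along the first axis. The split is positional, so
# shuffle |x| and |t| together beforehand if the data is ordered.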
def keydowns2image(keydowns, directional_feature, temporal_feature, scale=16,
stroke_width=2):
"""Converts a list of keydowns into image.
Args:
keydowns: [(key, t), ...] Training data as a list of keydowns.
directional_feature (boolean): True when using directional feature.
temporal_feature (boolean): True when using temporal feature.
scale (int): Scale of the image.
stroke_width (int): Brush thickness to draw.
Returns:
X_im: Image dataset in numpy array format. The shape differs by used
features.
(directional=True, temporal=True) => (scale, scale, 10)
(directional=True, temporal=False) => (scale, scale, 8)
(directional=False, temporal=True) => (scale, scale, 3)
(directional=False, temporal=False) => (scale, scale, 1)
"""
  # Translate keys to 2D points. [(key, t), ...] -> [(x, y, t), ...]
  X = keydowns2points(keydowns)
  # 0-1 normalization.
  X_norm = normalize_x(X)
  # Take differences. [(x, y, t), ...] -> [(dx, dy, dt), ...]
  X_diff = np.diff(X_norm, axis=0)
  # Encode pendown state. [(dx, dy, dt), ...] -> [(dx, dy, dt, pendown), ...]
  X_diff_encoded = pendown_encode(X_diff)
# Render into images.
X_im = generate_images(X_norm, X_diff_encoded, directional_feature,
temporal_feature, scale, stroke_width) / 255.
return X_im
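End to end, the preprocessing pipeline above can be exercised like this (illustrative keydown timings):

keydowns = [('q', 0), ('w', 80), ('e', 160), ('d', 240), ('c', 320)]
image = keydowns2image(keydowns, directional_feature=True,
                       temporal_feature=True)
print(image.shape)  # (16, 16, 10): 8 direction channels + 2 temporal channels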

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import nazoru
class TestNazoru(unittest.TestCase):
"""Test class for nazoru.py"""
def test_key2pos(self):
pos = nazoru.key2pos('a', ['abc', 'def'], 0.5)
    self.assertEqual((0, 0), pos)
    pos = nazoru.key2pos('e', ['abc', 'def'], 0.5)
    self.assertEqual((1.5, 1), pos)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NazoruNet.
NazoruNet is a customized version of MobileNet architecture and can be used to
predict input characters from visualized trace.
"""
from collections import namedtuple
from tensorflow.contrib import slim
import tensorflow as tf
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
def mobilenet_v1_base(inputs,
final_endpoint,
conv_defs,
min_depth=8,
depth_multiplier=1.0,
bn_decay=0.95,
output_stride=None,
scope=None):
"""Mobilenet v1.
Constructs a Mobilenet v1 network from inputs to the given final endpoint.
Customized to accept batch normalization decay parameter to accelerate
learning speed.
Args:
inputs: a tensor of shape [batch_size, height, width, channels].
final_endpoint: specifies the endpoint to construct the network up to. It
can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
      'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
'Conv2d_12_pointwise', 'Conv2d_13_pointwise'].
conv_defs: A list of ConvDef namedtuples specifying the net architecture.
min_depth: Minimum depth value (number of channels) for all convolution ops.
Enforced when depth_multiplier < 1, and not an active constraint when
depth_multiplier >= 1.
depth_multiplier: Float multiplier for the depth (number of channels)
for all convolution ops. The value must be greater than zero. Typical
usage will be to set this value in (0, 1) to reduce the number of
parameters or computation cost of the model.
bn_decay: Decay parameter for batch normalization layer.
output_stride: An integer that specifies the requested ratio of input to
output spatial resolution. If not None, then we invoke atrous convolution
if necessary to prevent the network from reducing the spatial resolution
of the activation maps. Allowed values are 8 (accurate fully convolutional
mode), 16 (fast fully convolutional mode), 32 (classification mode).
scope: Optional variable_scope.
Returns:
tensor_out: output tensor corresponding to the final_endpoint.
end_points: a set of activations for external use, for example summaries or
losses.
Raises:
ValueError: if final_endpoint is not set to one of the predefined values,
or depth_multiplier <= 0, or the target output_stride is not
allowed.
"""
  # depth() computes the thinned depth (number of channels) for each layer.
  depth = lambda d: max(int(d * depth_multiplier), min_depth)
  end_points = {}
if depth_multiplier <= 0:
raise ValueError('depth_multiplier is not greater than zero.')
if output_stride is not None and output_stride not in [8, 16, 32]:
raise ValueError('Only allowed output_stride values are 8, 16, 32.')
with tf.variable_scope(scope, 'MobilenetV1', [inputs]):
with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'):
# The current_stride variable keeps track of the output stride of the
# activations, i.e., the running product of convolution strides up to the
# current network layer. This allows us to invoke atrous convolution
# whenever applying the next convolution would result in the activations
# having output stride larger than the target output_stride.
current_stride = 1
# The atrous convolution rate parameter.
rate = 1
net = inputs
for i, conv_def in enumerate(conv_defs):
end_point_base = 'Conv2d_%d' % i
if output_stride is not None and current_stride == output_stride:
# If we have reached the target output_stride, then we need to employ
# atrous convolution with stride=1 and multiply the atrous rate by the
# current unit's stride for use in subsequent layers.
layer_stride = 1
layer_rate = rate
rate *= conv_def.stride
else:
layer_stride = conv_def.stride
layer_rate = 1
current_stride *= conv_def.stride
if isinstance(conv_def, Conv):
end_point = end_point_base
net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
stride=conv_def.stride,
normalizer_fn=slim.batch_norm,
normalizer_params={'decay': bn_decay},
scope=end_point)
end_points[end_point] = net
if end_point == final_endpoint:
return net, end_points
elif isinstance(conv_def, DepthSepConv):
end_point = end_point_base + '_depthwise'
          # By passing filters=None, separable_conv2d produces only a
          # depthwise convolution layer.
net = slim.separable_conv2d(net, None, conv_def.kernel,
depth_multiplier=1,
stride=layer_stride,
rate=layer_rate,
normalizer_fn=slim.batch_norm,
normalizer_params={'decay': bn_decay},
scope=end_point)
end_points[end_point] = net
if end_point == final_endpoint:
return net, end_points
end_point = end_point_base + '_pointwise'
net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
stride=1,
normalizer_fn=slim.batch_norm,
normalizer_params={'decay': bn_decay},
scope=end_point)
end_points[end_point] = net
if end_point == final_endpoint:
return net, end_points
else:
            raise ValueError('Unknown convolution type %s for layer %d'
                             % (type(conv_def).__name__, i))
raise ValueError('Unknown final endpoint %s' % final_endpoint)
def nazorunet(inputs,
num_classes,
conv_defs,
dropout_keep_prob=0.999,
is_training=True,
min_depth=8,
depth_multiplier=1.0,
prediction_fn=tf.contrib.layers.softmax,
spatial_squeeze=True,
reuse=None,
scope='NazoruNet',
global_pool=False):
"""Customized MobileNet model for Nazoru Input.
Args:
inputs: a tensor of shape [batch_size, height, width, channels].
num_classes: number of predicted classes. If 0 or None, the logits layer
is omitted and the input features to the logits layer (before dropout)
are returned instead.
dropout_keep_prob: the percentage of activation values that are retained.
is_training: whether is training or not.
min_depth: Minimum depth value (number of channels) for all convolution ops.
Enforced when depth_multiplier < 1, and not an active constraint when
depth_multiplier >= 1.
depth_multiplier: Float multiplier for the depth (number of channels)
for all convolution ops. The value must be greater than zero. Typical
usage will be to set this value in (0, 1) to reduce the number of
parameters or computation cost of the model.
conv_defs: A list of ConvDef namedtuples specifying the net architecture.
prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape [B, C]; if False, logits is
      of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
global_pool: Optional boolean flag to control the avgpooling before the
logits layer. If false or unset, pooling is done with a fixed window
that reduces default-sized inputs to 1x1, while larger inputs lead to
larger outputs. If true, any input size is pooled down to 1x1.
Returns:
net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
is a non-zero integer, or the non-dropped-out input to the logits layer
if num_classes is 0 or None.
end_points: a dictionary from components of the network to the corresponding
activation.
Raises:
ValueError: Input rank is invalid.
"""
input_shape = inputs.get_shape().as_list()
if len(input_shape) != 4:
raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
len(input_shape))
with tf.variable_scope(scope, 'NazoruNet', [inputs], reuse=reuse) as scope:
with slim.arg_scope([slim.batch_norm, slim.dropout],
is_training=is_training):
final_endpoint = 'Conv2d_%i_pointwise' % (len(conv_defs) - 1)
net, end_points = mobilenet_v1_base(inputs, scope=scope,
min_depth=min_depth,
depth_multiplier=depth_multiplier,
conv_defs=conv_defs,
final_endpoint=final_endpoint)
with tf.variable_scope('Logits'):
if global_pool:
# Global average pooling.
net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
end_points['global_pool'] = net
else:
# Pooling with a fixed kernel size.
kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
net = slim.avg_pool2d(net, kernel_size, padding='VALID',
scope='AvgPool_1a')
end_points['AvgPool_1a'] = net
if not num_classes:
return net, end_points
# 1 x 1 x 1024
net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='Conv2d_1c_1x1')
if spatial_squeeze:
logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
end_points['Logits'] = logits
if prediction_fn:
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
return logits, end_points
def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
"""Define kernel size which is automatically reduced for small input.
If the shape of the input images is unknown at graph construction time this
function assumes that the input images are large enough.
Args:
input_tensor: input tensor of size [batch_size, height, width, channels].
kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
Returns:
a tensor with the kernel size.
"""
shape = input_tensor.get_shape().as_list()
if shape[1] is None or shape[2] is None:
kernel_size_out = kernel_size
else:
kernel_size_out = [min(shape[1], kernel_size[0]),
min(shape[2], kernel_size[1])]
return kernel_size_out
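For orientation, a toy invocation (illustrative; the conv_defs values and the class count are invented here, not the ones used for training):

import tensorflow as tf
conv_defs = [
    Conv(kernel=[3, 3], stride=2, depth=32),
    DepthSepConv(kernel=[3, 3], stride=1, depth=64),
    DepthSepConv(kernel=[3, 3], stride=2, depth=128),
]
inputs = tf.placeholder(tf.float32, [None, 16, 16, 10])
logits, end_points = nazorunet(inputs, num_classes=88, conv_defs=conv_defs,
                               is_training=False)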

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import tensorflow as tf
from . import lib
from . import utils
import numpy as np
def _load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
class NazoruPredictor():
def __init__(self, model_file):
graph = _load_graph(model_file)
self._graph = graph
self._input_operation = graph.get_operation_by_name(
'import/' + lib.INPUT_NODE_NAME)
self._output_operation = graph.get_operation_by_name(
'import/' + lib.OUTPUT_NODE_NAME)
def _predict(self, data):
with utils.Measure('inputs'):
inputs = lib.keydowns2image(data, True, True, 16, 2)
inputs = np.expand_dims(inputs, axis=0)
with utils.Measure('sess.run'):
with tf.Session(graph=self._graph) as sess:
result = sess.run(self._output_operation.outputs[0],
{self._input_operation.outputs[0]: inputs})[0]
return result
def predict_top_n(self, data, n):
"""Predict the charactor drawn by |data|.
Args:
data: [(key, time)] |time| is elapsed time since the first character in ms.
n: integer of the number of the return value.
Returns:
ans: [(kana, key, probability)] sorted by the probability.
"""
result = self._predict(data)
ans = []
for i in result.argsort()[::-1][:n]:
ans.append((lib.KANAS[i], lib.KEYS[i], result[i]))
return ans
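Putting it together (illustrative; the import path of the bundled-graph helper is an assumption):

from nazoru.core import get_default_graph_path  # hypothetical module path
predictor = NazoruPredictor(get_default_graph_path())
data = [('q', 0), ('w', 75), ('e', 160)]  # (key, elapsed ms) keydowns
for kana, key, prob in predictor.predict_top_n(data, 3):
  print(u'{} ({}): {:.3f}'.format(kana, key, prob))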

View File

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class Measure():
def __init__(self, tag):
self._tag = tag
  def __enter__(self):
    self._start = time.time()
    return self
  def __exit__(self, exc_type, exc_value, traceback):
    now = time.time()
    print('[{0}] {1} ms'.format(self._tag, (now - self._start) * 1E3))
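Usage mirrors the call sites in predictor.py; timing is printed when the block exits:

with Measure('inputs'):
  inputs = lib.keydowns2image(data, True, True, 16, 2)
# -> prints something like '[inputs] 12.3 ms'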