mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-09 01:03:26 +03:00
Add nazoru-input in mozc-nazoru/
This commit is contained in:
6
mozc-nazoru/src/nazoru/__init__.py
Normal file
6
mozc-nazoru/src/nazoru/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
def get_default_graph_path():
  """Returns the absolute path of the bundled pre-trained nazoru graph."""
  import os
  package_dir = os.path.dirname(os.path.abspath(__file__))
  return os.path.join(package_dir, 'data', 'optimized_nazoru.pb')
|
||||
57
mozc-nazoru/src/nazoru/bluetooth.py
Normal file
57
mozc-nazoru/src/nazoru/bluetooth.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import serial
|
||||
import struct
|
||||
|
||||
class Bluetooth():
  """Sends keystrokes to a paired host through a UART Bluetooth module.

  When the serial device cannot be opened (e.g. when not running on the
  target board), falls back to a dummy mode that only prints to stdout.
  """

  DEVICE_FILE = '/dev/serial0'
  BAUDRATE = 115200

  def __init__(self):
    try:
      self._conn = serial.Serial(self.DEVICE_FILE, self.BAUDRATE)
      self._dummy = False
      print('Bluetooth')
    except serial.SerialException:
      self._conn = None
      self._dummy = True
      print('Dummy Bluetooth')

  def send(self, string):
    """Send |string| as a series of characters. |string| should be
    alphabets, numbers and symbols which can be typed from your keyboard."""
    if self._dummy:
      print('bluetooth: {}'.format(string))
      return
    self._conn.write(string)

  # See http://ww1.microchip.com/downloads/en/DeviceDoc/bluetooth_cr_UG-v1.0r.pdf
  # for detail.
  UART_CODES = {
      'KEY_RIGHT': 7,
      'KEY_BACKSPACE': 8,
      'KEY_ENTER': 10,
      'KEY_LEFT': 11,
      'KEY_DOWN': 12,
      'KEY_UP': 14,
  }

  def command(self, cmd):
    """Sends a special (non-printable) key using its UART raw code."""
    code = self.UART_CODES.get(cmd)
    if code is None:
      print('Unknown Command: {}'.format(cmd))
      return
    self.send(struct.pack('b', code))
|
||||
26
mozc-nazoru/src/nazoru/core.py
Normal file
26
mozc-nazoru/src/nazoru/core.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from . import lib
|
||||
from .bluetooth import Bluetooth
|
||||
from .keyboard_recorder import create_keyboard_recorder
|
||||
from .nazorunet import nazorunet
|
||||
from .nazorunet import Conv
|
||||
from .nazorunet import DepthSepConv
|
||||
from .predictor import NazoruPredictor
|
||||
|
||||
BIN
mozc-nazoru/src/nazoru/data/optimized_nazoru.pb
Normal file
BIN
mozc-nazoru/src/nazoru/data/optimized_nazoru.pb
Normal file
Binary file not shown.
254
mozc-nazoru/src/nazoru/keyboard_recorder.py
Normal file
254
mozc-nazoru/src/nazoru/keyboard_recorder.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from enum import Enum
|
||||
import datetime
|
||||
import fcntl
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import termios
|
||||
|
||||
try:
|
||||
import evdev
|
||||
except ImportError:
|
||||
evdev = None
|
||||
|
||||
def _set_raw_mode_stdin():
  """Puts stdin into non-echoing, non-canonical, non-blocking mode.

  Returns:
    A zero-argument function which restores the terminal attributes and
    file status flags captured before the change.
  """
  fd = sys.stdin.fileno()
  saved_attrs = termios.tcgetattr(fd)
  saved_flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  raw_attrs = termios.tcgetattr(fd)
  # Clear ECHO (no echoing of typed keys) and ICANON (deliver bytes
  # immediately instead of line-buffering) in the local-modes field.
  raw_attrs[3] = raw_attrs[3] & ~termios.ECHO & ~termios.ICANON
  termios.tcsetattr(fd, termios.TCSADRAIN, raw_attrs)

  fcntl.fcntl(fd, fcntl.F_SETFL, saved_flags | os.O_NONBLOCK)

  def reset_raw_mode():
    termios.tcsetattr(fd, termios.TCSANOW, saved_attrs)
    fcntl.fcntl(fd, fcntl.F_SETFL, saved_flags)

  return reset_raw_mode
|
||||
|
||||
def _set_raw_mode_general(f):
|
||||
fno = f.fileno()
|
||||
fcntl_old = fcntl.fcntl(fno, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old | os.O_NONBLOCK)
|
||||
|
||||
def reset_raw_mode():
|
||||
fcntl.fcntl(fno, fcntl.F_SETFL, fcntl_old)
|
||||
|
||||
return reset_raw_mode
|
||||
|
||||
class KeyboardRecorder():
  """Base class of keyboard recorders; records nothing by itself."""

  def __init__(self, verbose=False):
    self._verbose = verbose

  def record(self):
    """Returns a (data, command) tuple; the base implementation is empty."""
    return (None, None)

  def log(self, *args, **kwargs):
    # TODO(shimazu): Use logger.
    if not self._verbose:
      return
    print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
class KeyboardRecorderFromConsole(KeyboardRecorder):
  """Records keystrokes by reading characters directly from stdin.

  Stdin is switched to non-echoing, non-canonical, non-blocking mode
  (via _set_raw_mode_stdin) for the duration of a recording so that
  individual keydowns can be timestamped as they arrive.
  """

  def __init__(self, verbose=False):
    KeyboardRecorder.__init__(self, verbose)
    self.log('Input from console')

  def record(self):
    """
    Returns a tuple of |data| and |command|.
    |data|: an array of tuples of keys and time from the first character.
    |command|: None
    """
    recording = False
    start_time = None
    last_time = None
    # Stop recording after this many seconds without a new keystroke.
    wait_seconds = 2
    data = []
    reset_raw_mode = _set_raw_mode_stdin()
    try:
      while 1:
        try:
          # stdin is non-blocking, so read() raises IOError when no byte
          # is available. NOTE(review): catching IOError here presumably
          # targets Python 2 semantics — TODO confirm on Python 3.
          key = sys.stdin.read(1)
        except IOError:
          key = None
        finally:
          # Timestamp every poll iteration, with or without a key.
          now = datetime.datetime.now()
        if key == '\n':
          # Enter aborts the recording without returning any data.
          return (None, None)
        elif key:
          if not recording:
            # First keystroke starts the clock.
            recording = True
            start_time = datetime.datetime.now()
          elapsed_time = now - start_time
          elapsed_ms = int(elapsed_time.total_seconds() * 1000)
          last_time = now
          data.append((key, elapsed_ms))
          self.log(key, elapsed_ms)
        if last_time and (now - last_time).total_seconds() > wait_seconds:
          # Idle timeout: the trace is considered complete.
          break
    finally:
      # Always restore the terminal, including on the early return above.
      reset_raw_mode()
    return (data, None)
|
||||
|
||||
class KeyboardRecorderFromEvdev(KeyboardRecorder):
  """Records keystrokes from a Linux evdev keyboard device.

  Works without a controlling terminal; requires the python-evdev package
  and at least one input device that reports EV_KEY with a KEY_ESC code.
  """

  # evdev key code -> printable character (main block of a US-style
  # layout; modifier and function keys are intentionally absent so they
  # surface as commands in record()).
  KEYS = {
      2: '1', 3: '2', 4: '3', 5: '4', 6: '5',
      7: '6', 8: '7', 9: '8', 10: '9', 11: '0',
      12: '-', 13: '=',
      16: 'q', 17: 'w', 18: 'e', 19: 'r', 20: 't',
      21: 'y', 22: 'u', 23: 'i', 24: 'o', 25: 'p',
      26: '[', 27: ']',
      30: 'a', 31: 's', 32: 'd', 33: 'f', 34: 'g',
      35: 'h', 36: 'j', 37: 'k', 38: 'l', 39: ';',
      40: '\'', 43: '\\',
      44: 'z', 45: 'x', 46: 'c', 47: 'v', 48: 'b',
      49: 'n', 50: 'm', 51: ',', 52: '.', 53: '/'
  }
  # Stop recording after this many seconds without a key event.
  WAIT_SECONDS = 2

  def __init__(self, verbose=False):
    if evdev is None:
      # evdev failed to import at module load time (non-Linux platform).
      raise TypeError('KeyboardRecorderFromEvdev needs to be used on Linux ' +
                      '(or POSIX compatible) system. Instead, You can try it ' +
                      'on your console.')
    KeyboardRecorder.__init__(self, verbose)
    self.log('Input from evdev')
    keyboards = []
    ecode_ev_key = evdev.ecodes.ecodes['EV_KEY']
    ecode_key_esc = evdev.ecodes.ecodes['KEY_ESC']
    for device in [evdev.InputDevice(fn) for fn in evdev.list_devices()]:
      # TODO(shimazu): Consider more solid criteria to get 'regular' keyboards.
      # Heuristic: a "keyboard" is any device that reports key events and
      # has an ESC key.
      if ecode_ev_key in device.capabilities() and \
         ecode_key_esc in device.capabilities()[ecode_ev_key]:
        keyboards.append(device)
    if len(keyboards) == 0:
      raise IOError('No keyboard found.')
    self._keyboards = keyboards
    for keyboard in keyboards:
      self.log('----')
      self.log(keyboard)
      self.log('name: {0}'.format(keyboard.name))
      self.log('phys: {0}'.format(keyboard.phys))
      self.log('repeat: {0}'.format(keyboard.repeat))
      self.log('info: {0}'.format(keyboard.info))
      self.log(keyboard.capabilities(verbose=True))

  def record(self):
    """
    Returns a tuple of |data| and |command|.
    |data|: an array of tuples of keys and time from the first character. None
      if the input is non-alphabet/numeric/symbols like ENTER, arrows etc.
    |command|: Commands like "KEY_ENTER" or None if |data| is valid.
    """
    start_time = None
    last_time = None
    data = []
    while True:
      # TODO(shimazu): Check inputs from all keyboards.
      # NOTE(review): read_one() is non-blocking, so this loop busy-waits
      # between events — consider select()/read() to avoid CPU spin.
      event = self._keyboards[0].read_one()
      now = datetime.datetime.now()
      if last_time and (now - last_time).total_seconds() > self.WAIT_SECONDS:
        # Idle timeout: the trace is considered complete.
        break
      if event is None:
        continue
      name = evdev.ecodes.bytype[event.type][event.code]
      ev_type = evdev.ecodes.EV[event.type]
      if ev_type != 'EV_KEY':
        continue
      # Keyboard input
      self.log('----')
      self.log(event)
      self.log('name: {}'.format(name))
      self.log('type: {}'.format(ev_type))
      # Check if the event is from releasing the button
      if event.value == 0:
        continue

      # It may be a non-alphabet/numeric/symbol key. Return it as a command.
      if event.code not in self.KEYS:
        if start_time is not None:
          # Already recording a trace; ignore command keys until timeout.
          continue
        return (None, name)
      last_time = now
      if start_time is None:
        # First printable keystroke starts the clock.
        start_time = now
      elapsed_ms = int((now - start_time).total_seconds() * 1000)
      data.append((self.KEYS[event.code], elapsed_ms))
    return (data, None)
|
||||
|
||||
InputSource = Enum('InputSource', 'EVDEV CONSOLE')


def create_keyboard_recorder(verbose=False, source=None):
  """Creates KeyboardRecorder.

  Args:
    verbose: Print the detail of input when it's true.
    source: InputSource.EVDEV, InputSource.CONSOLE or None (default)

  Returns:
    recorder: Corresponding KeyboardRecorder. If |source| is None, returns
    KeyboardRecorderFromConsole when stdin is attached to a console (isatty is
    true). Otherwise, returns KeyboardRecorderFromEvdev.
  """
  if source == InputSource.CONSOLE:
    return KeyboardRecorderFromConsole(verbose=verbose)
  if source == InputSource.EVDEV:
    return KeyboardRecorderFromEvdev(verbose=verbose)
  # No explicit source: pick by whether stdin is an interactive terminal.
  if sys.__stdin__.isatty():
    return KeyboardRecorderFromConsole(verbose=verbose)
  return KeyboardRecorderFromEvdev(verbose=verbose)
|
||||
82
mozc-nazoru/src/nazoru/led.py
Normal file
82
mozc-nazoru/src/nazoru/led.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
class LEDBase():
  """No-op LED interface, used when GPIO hardware is unavailable."""

  def __init__(self, pin):
    self._pin = pin

  def on(self):
    """Turns the LED on (no-op in the base class)."""

  def off(self):
    """Turns the LED off (no-op in the base class)."""

  def blink(self, interval):
    """Blinks the LED with the given interval (no-op in the base class)."""
|
||||
|
||||
try:
  import RPi.GPIO as GPIO
  import threading

  class LED(LEDBase):
    """Drives an LED on a Raspberry Pi GPIO pin (BOARD numbering).

    Output levels are inverted: the pin is driven low for on() and high
    for off() — presumably the LEDs are wired active-low; TODO confirm
    against the board schematic.
    """

    ON = 'ON'
    OFF = 'OFF'

    def __init__(self, pin):
      GPIO.setmode(GPIO.BOARD)
      self._pin = pin
      # Serializes state/output updates between callers and the blink
      # timer thread.
      self._lock = threading.Lock()
      self._timer = None
      GPIO.setup(pin, GPIO.OUT)
      self.off()

    def on(self):
      with self._lock:
        self._state = self.ON
        # Active-low: low level lights the LED (see class docstring).
        GPIO.output(self._pin, False)
        self._ensure_stop_timer()

    def off(self):
      with self._lock:
        self._state = self.OFF
        GPIO.output(self._pin, True)
        self._ensure_stop_timer()

    def blink(self, interval):
      self._ensure_stop_timer()
      def toggle():
        self._timer = None
        if self._state == self.ON:
          self.off()
        else:
          self.on()
        # Re-arm: each tick schedules the next one, so blinking continues
        # until on()/off() cancels the pending timer.
        self._timer = threading.Timer(interval, toggle)
        self._timer.daemon = True
        self._timer.start()
      toggle()

    def _ensure_stop_timer(self):
      # Cancels the pending blink tick, if any.
      if self._timer is not None:
        self._timer.cancel()
        self._timer = None

except ImportError as e:
  # RPi.GPIO is unavailable off-device; fall back to the no-op base class.
  import sys

  class LED(LEDBase):
    pass
|
||||
|
||||
# Status LEDs, addressed by their physical BOARD pin numbers.
LED_BLUE = LED(38)
LED_RED = LED(40)
|
||||
410
mozc-nazoru/src/nazoru/lib.py
Normal file
410
mozc-nazoru/src/nazoru/lib.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Nazoru input library.
|
||||
This is a collection of methods to preprocess input stroke data before any
|
||||
training starts.
|
||||
"""
|
||||
|
||||
import time
|
||||
import random
|
||||
import cairocffi as cairo
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
# Name scope of the network, and the graph node names to feed and fetch
# when running the frozen model.
SCOPE = 'Nazorunet'
INPUT_NODE_NAME = 'inputs'
OUTPUT_NODE_NAME = SCOPE + '/Predictions/Reshape_1'
# Characters the model can predict. KANAS[i] is the display character for
# the class whose romanized key is KEYS[i]; the two sequences must stay in
# lockstep.
KANAS = (u'あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほ'
         u'まみむめもやゆよらりるれろわゐんゑを'
         u'abcdefghijklmnopqrstuvwxyz1234567890'
         u'♡ーずぐ')
KEYS = ('a', 'i', 'u', 'e', 'o',
        'ka', 'ki', 'ku', 'ke', 'ko',
        'sa', 'si', 'su', 'se', 'so',
        'ta', 'ti', 'tu', 'te', 'to',
        'na', 'ni', 'nu', 'ne', 'no',
        'ha', 'hi', 'hu', 'he', 'ho',
        'ma', 'mi', 'mu', 'me', 'mo',
        'ya', 'yu', 'yo',
        'ra', 'ri', 'ru', 're', 'ro',
        'wa', 'wi', 'nn', 'we', 'wo',
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
        '1', '2', '3', '4', '5', '6', '7', '8', '9', '0',
        'ha-to', '-', 'zu', 'gu')
|
||||
|
||||
class KeyboardArrangement(Enum):
  """Enum for keyboard arrangements.

  Each member's value is a list of rows, top to bottom, each row being the
  string of keys from left to right.
  """
  qwerty_jis = [
      u'1234567890-^¥',
      u'qwertyuiop@[',
      u'asdfghjkl;:]',
      u'zxcvbnm,./_',
  ]


def key2pos(key, arrangement=KeyboardArrangement.qwerty_jis.value, offset=0.5):
  """Returns the key position.

  Args:
    key (string): Key to get position.
    arrangement (list): Keyboard arrangement.
    offset (number): How much the keys are shifting by row.

  Returns:
    position (tuple(number, number)): Position (x, y) of |key|, or None
    when the key does not appear in |arrangement|.
  """
  for row_index, row in enumerate(arrangement):
    if key not in row:
      continue
    # Each row is shifted right by |offset| relative to the row above.
    return (row.index(key) + row_index * offset, row_index)
  return None
|
||||
|
||||
|
||||
def keydowns2points(keydowns):
  """Translates keydowns to points.

  Args:
    keydowns: [(key, t), ...] List of keydowns.

  Returns:
    points: [(x, y, t), ...] List of points. Keydowns whose key has no
    position in the default arrangement are dropped.
  """
  points = []
  for key, t in keydowns:
    pos = key2pos(key)
    if pos is not None:
      points.append((pos[0], pos[1], t))
  return points
|
||||
|
||||
|
||||
def normalize_x(x):
  """Normalizes position.

  Args:
    x (list): [[x, y, t], ...] List of points whose position (x, y) should
      be rescaled into the 0-1 range.

  Returns:
    x (numpy.array): [[x', y', t], ...] Points with normalized positions
    (x', y'); the time column is left untouched.
  """
  x = np.array(x)
  xy = x[:, :2]
  lo = xy.min(axis=0)
  hi = xy.max(axis=0)
  x[:, :2] = (xy - lo) / (hi - lo)
  return x
|
||||
|
||||
|
||||
def pendown_encode(x_diff, sigma=1.6):
  """Encodes time difference into pendown state.

  A diff whose dt exceeds mean(dt) + sigma * std(dt) is treated as a pause
  between strokes (pen up, 0); every other diff is pen down (1).

  Args:
    x_diff (list): [[dx, dy, dt], ...] List of diffs to encode.

  Returns:
    x_diff_encoded (list): [[dx, dy, dt, pendown], ...] Encoded list of diffs.
  """
  dts = x_diff[:, 2]
  threshold = np.mean(dts) + sigma * np.std(dts)
  pendown_column = [[1] if dt_i <= threshold else [0] for dt_i in dts]
  return np.concatenate((x_diff, pendown_column), axis=1)
|
||||
|
||||
|
||||
def surface_to_array(surface):
  """Returns image array from cairo surface.

  Args:
    surface: Cairo surface to translate.

  Returns:
    Grayscale ('L' mode) numpy array of the rendered surface.
  """
  png_buffer = BytesIO()
  surface.write_to_png(png_buffer)
  image = Image.open(BytesIO(png_buffer.getvalue()))
  return np.asarray(image.convert('L'))
|
||||
|
||||
|
||||
def get_direction(diff):
  """Returns directions and weights for 8-directional features.

  For more detail, see

  - Bai, Zhen-Long, and Qiang Huo. "A study on the use of 8-directional
    features for online handwritten Chinese character recognition."
  - Liu, Cheng-Lin, and Xiang-Dong Zhou. "Online Japanese character
    recognition using trajectory-based normalization and direction feature
    extraction."

  Weight is halved for pen-up states.

  Args:
    diff (numpy.array): Encoded diff vector (dx, dy, dt, pendown).

  Returns:
    First direction (Right (0), Down (2), Left (4), Up (6)) and its weight,
    and second direction (Bottom right (1), Bottom left (3), Up left (5),
    Up right (7)) and its weight. A zero-length diff yields (0, 0, 1, 0).
  """
  dx, dy = diff[0], diff[1]
  abs_dx, abs_dy = np.abs(dx), np.abs(dy)

  length = np.linalg.norm(diff[:2])
  if length == 0:
    return 0, 0, 1, 0

  # Axis-aligned component: right/left when horizontal motion dominates,
  # down/up otherwise.
  if abs_dx >= abs_dy:
    direction1 = 0 if dx >= 0 else 4
  else:
    direction1 = 2 if dy >= 0 else 6

  # Diagonal component, chosen by quadrant.
  if dx >= 0:
    direction2 = 1 if dy >= 0 else 7
  else:
    direction2 = 3 if dy >= 0 else 5

  weight1 = np.abs(abs_dx - abs_dy) / length
  weight2 = np.sqrt(2) * min(abs_dx, abs_dy) / length
  if diff[3] == 0:
    # Pen-up segment: contribute at half strength.
    weight1 /= 2
    weight2 /= 2
  return direction1, weight1, direction2, weight2
|
||||
|
||||
|
||||
def generate_images(x_norm, x_diff_encoded, directional_feature,
                    temporal_feature, scale, stroke_width):
  """Generates image array from strokes.

  Args:
    x_norm: [(x', y', t), ...] Normalized points.
    x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
    directional_feature (boolean): True when using directional feature.
    temporal_feature (boolean): True when using temporal feature.
    scale (int): Scale of the image.
    stroke_width (int): Brush thickness to draw.

  Returns:
    images (numpy.array): An array of images of shape (scale, scale, C).
    C is 8 when |directional_feature| is True, otherwise 1; two more
    channels are appended when |temporal_feature| is True (e.g. both True
    gives (scale, scale, 10)).
  """
  if directional_feature:
    images = generate_image_direct_decomp(
        x_norm, x_diff_encoded, scale, stroke_width)
  else:
    images = generate_image_plain(x_norm, x_diff_encoded, scale, stroke_width)

  if temporal_feature:
    # Two temporal maps: one fading out over time, one fading in.
    for inversed in (False, True):
      temporal = generate_image_temporal(
          x_norm, x_diff_encoded, scale, stroke_width, inversed=inversed)
      images = np.concatenate((images, temporal), axis=-1)

  return images
|
||||
|
||||
|
||||
def generate_image_direct_decomp(x_norm, x_diff_encoded, scale, stroke_width):
  """Generates image array from strokes using direction feature.

  Args:
    x_norm: [(x', y', t), ...] Normalized points.
    x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
    scale (int): scale of the image.
    stroke_width (int): Brush thickness to draw.

  Returns:
    image (numpy.array): Image array with a shape of (scale, scale, 8).
  """

  # One 8-bit alpha canvas per direction (see get_direction for the
  # direction numbering).
  surfaces = [cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)
              for _ in range(8)]

  curr_x = x_norm[0][0]
  curr_y = x_norm[0][1]

  for i, diff in enumerate(x_diff_encoded):
    direction1, weight1, direction2, weight2 = get_direction(diff)

    # Draw the segment on the axis-aligned direction's canvas with its
    # weight as the stroke intensity...
    ctx = cairo.Context(surfaces[direction1])
    ctx.move_to(curr_x * scale, curr_y * scale)
    ctx.set_line_width(stroke_width)
    ctx.set_source_rgba(1, 1, 1, weight1)
    ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
    ctx.stroke()

    # ...and again on the diagonal direction's canvas with its own weight.
    ctx = cairo.Context(surfaces[direction2])
    ctx.move_to(curr_x * scale, curr_y * scale)
    ctx.set_line_width(stroke_width)
    ctx.set_source_rgba(1, 1, 1, weight2)
    ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
    ctx.stroke()

    # Advance the pen to the segment's end point.
    curr_x += diff[0]
    curr_y += diff[1]

  # Stack the eight grayscale canvases as channels: (scale, scale, 8).
  return np.array([
      surface_to_array(surface) for surface in surfaces]).transpose(1, 2, 0)
|
||||
|
||||
|
||||
def generate_image_plain(x_norm, x_diff_encoded, scale, stroke_width):
  """Generates image array from strokes without direction feature.

  Args:
    x_norm: [(x', y', t), ...] Normalized points.
    x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
    scale (int): scale of the image.
    stroke_width (int): Brush thickness to draw.

  Returns:
    image (numpy.array): Image array with a shape of (scale, scale, 1).
  """

  surface = cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)

  pen_x = x_norm[0][0]
  pen_y = x_norm[0][1]

  for diff in x_diff_encoded:
    ctx = cairo.Context(surface)
    ctx.move_to(pen_x * scale, pen_y * scale)
    ctx.set_line_width(stroke_width)
    # Pen-up segments are drawn at half intensity.
    alpha = 1 if diff[3] == 1 else 0.5
    ctx.set_source_rgba(1, 1, 1, alpha)
    pen_x += diff[0]
    pen_y += diff[1]
    ctx.line_to(pen_x * scale, pen_y * scale)
    ctx.stroke()

  return surface_to_array(surface).reshape(scale, scale, 1)
|
||||
|
||||
|
||||
def generate_image_temporal(x_norm, x_diff_encoded, scale, stroke_width,
                            steepness=2, inversed=False):
  """Generates image array from strokes weighted by drawing time.

  Segment opacity decays from 1 at the start of the trace toward 0 at the
  end (or the reverse when |inversed| is True), raised to |steepness| to
  sharpen the falloff; pen-up segments are drawn at half weight.

  Args:
    x_norm: [(x', y', t), ...] Normalized points.
    x_diff_encoded: [(dx, dy, dt, pendown), ...] Normalized diffs.
    scale (int): scale of the image.
    stroke_width (int): Brush thickness to draw.
    steepness (number): Exponent applied to the linear time weight.
    inversed (boolean): Fade in instead of out when True.

  Returns:
    image (numpy.array): Image array with a shape of (scale, scale, 1).
  """
  surface = cairo.ImageSurface(cairo.FORMAT_A8, scale, scale)

  curr_x = x_norm[0][0]
  curr_y = x_norm[0][1]
  # Time drawn so far, accumulated from the dt column of the diffs.
  spent_t = 0

  for i, diff in enumerate(x_diff_encoded):
    ctx = cairo.Context(surface)
    ctx.move_to(curr_x * scale, curr_y * scale)
    ctx.set_line_width(stroke_width)
    # Fraction of total drawing time remaining; x_norm[-1][2] is the
    # timestamp of the last point.
    weight = 1 - spent_t / x_norm[-1][2]
    if inversed: weight = 1 - weight
    weight = max(weight, 0) ** steepness
    if diff[3] == 0: weight /= 2
    ctx.set_source_rgba(1, 1, 1, weight)
    ctx.line_to((curr_x + diff[0]) * scale, (curr_y + diff[1]) * scale)
    ctx.stroke()

    curr_x += diff[0]
    curr_y += diff[1]
    spent_t += diff[2]
  return surface_to_array(surface).reshape(scale, scale, 1)
|
||||
|
||||
|
||||
def split_data(x, t, val_rate, test_rate):
  """Splits data into training, validation, and testing data.

  Args:
    x: Data to split.
    t: Label to split.
    val_rate: What percentage of data to use as a validation set.
    test_rate: What percentage of data to use as a testing set.

  Returns:
    train_x: Training inputs.
    train_t: Training labels.
    val_x: Validation inputs.
    val_t: Validation labels.
    test_x: Testing inputs.
    test_t: Testing labels.
  """

  n = x.shape[0]
  # Boundaries of the three consecutive slices.
  val_start = int(n * (1 - val_rate - test_rate))
  test_start = int(n * (1 - test_rate))
  return (x[:val_start], t[:val_start],
          x[val_start:test_start], t[val_start:test_start],
          x[test_start:], t[test_start:])
|
||||
|
||||
|
||||
def keydowns2image(keydowns, directional_feature, temporal_feature, scale=16,
                   stroke_width=2):
  """Converts a list of keydowns into image.

  Args:
    keydowns: [(key, t), ...] Training data as a list of keydowns.
    directional_feature (boolean): True when using directional feature.
    temporal_feature (boolean): True when using temporal feature.
    scale (int): Scale of the image.
    stroke_width (int): Brush thickness to draw.

  Returns:
    X_im: Image dataset in numpy array format. The shape differs by used
      features.
      (directional=True, temporal=True) => (scale, scale, 10)
      (directional=True, temporal=False) => (scale, scale, 8)
      (directional=False, temporal=True) => (scale, scale, 3)
      (directional=False, temporal=False) => (scale, scale, 1)
  """

  # Translate keys to 2D points: [(key, t), ...] -> [(x, y, t), ...]
  points = keydowns2points(keydowns)

  # Rescale positions into the 0-1 square.
  points_norm = normalize_x(points)

  # Successive differences: [(x, y, t), ...] -> [(dx, dy, dt), ...]
  diffs = np.diff(points_norm, axis=0)

  # Tag each diff with a pendown state: -> [(dx, dy, dt, pendown), ...]
  diffs_encoded = pendown_encode(diffs)

  # Render into images and rescale pixel intensities into 0-1.
  return generate_images(points_norm, diffs_encoded, directional_feature,
                         temporal_feature, scale, stroke_width) / 255.
|
||||
30
mozc-nazoru/src/nazoru/nazoru_test.py
Normal file
30
mozc-nazoru/src/nazoru/nazoru_test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import nazoru
|
||||
|
||||
class TestNazoru(unittest.TestCase):
  """Test class for nazoru.py"""

  def test_key2pos(self):
    arrangement = ['abc', 'def']
    # First row: no offset applied.
    self.assertEqual((0, 0), nazoru.key2pos('a', arrangement, 0.5))
    # Second row: x is shifted by one row's offset.
    self.assertEqual((1.5, 1), nazoru.key2pos('e', arrangement, 0.5))
|
||||
|
||||
# Allow running this test module directly from the command line.
if __name__ == "__main__":
  unittest.main()
|
||||
259
mozc-nazoru/src/nazoru/nazorunet.py
Normal file
259
mozc-nazoru/src/nazoru/nazorunet.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""NazoruNet.
|
||||
|
||||
NazoruNet is a customized version of MobileNet architecture and can be used to
|
||||
predict input characters from visualized trace.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
from tensorflow.contrib import slim
|
||||
import tensorflow as tf
|
||||
|
||||
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
|
||||
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
|
||||
|
||||
|
||||
def mobilenet_v1_base(inputs,
                      final_endpoint,
                      conv_defs,
                      min_depth=8,
                      depth_multiplier=1.0,
                      bn_decay=0.95,
                      output_stride=None,
                      scope=None):
  """Mobilenet v1.

  Constructs a Mobilenet v1 network from inputs to the given final endpoint.
  Customized to accept batch normalization decay parameter to accelerate
  learning speed.

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
      'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
      'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
      'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
      'Conv2d_12_pointwise', 'Conv2d_13_pointwise'].
    conv_defs: A list of ConvDef namedtuples specifying the net architecture.
    min_depth: Minimum depth value (number of channels) for all convolution
      ops. Enforced when depth_multiplier < 1, and not an active constraint
      when depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    bn_decay: Decay parameter for batch normalization layer.
    output_stride: An integer that specifies the requested ratio of input to
      output spatial resolution. If not None, then we invoke atrous
      convolution if necessary to prevent the network from reducing the
      spatial resolution of the activation maps. Allowed values are 8
      (accurate fully convolutional mode), 16 (fast fully convolutional
      mode), 32 (classification mode).
    scope: Optional variable_scope.

  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a set of activations for external use, for example summaries
      or losses.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values,
      or depth_multiplier <= 0, or the target output_stride is not allowed.
  """
  # Thinned depth for each layer: scale by the multiplier but never go
  # below |min_depth|.
  depth = lambda d: max(int(d * depth_multiplier), min_depth)
  end_points = {}

  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')

  if output_stride is not None and output_stride not in [8, 16, 32]:
    raise ValueError('Only allowed output_stride values are 8, 16, 32.')

  with tf.variable_scope(scope, 'MobilenetV1', [inputs]):
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'):
      # The current_stride variable keeps track of the output stride of the
      # activations, i.e., the running product of convolution strides up to
      # the current network layer. This allows us to invoke atrous
      # convolution whenever applying the next convolution would result in
      # the activations having output stride larger than the target
      # output_stride.
      current_stride = 1

      # The atrous convolution rate parameter.
      rate = 1

      net = inputs
      for i, conv_def in enumerate(conv_defs):
        end_point_base = 'Conv2d_%d' % i

        if output_stride is not None and current_stride == output_stride:
          # If we have reached the target output_stride, then we need to
          # employ atrous convolution with stride=1 and multiply the atrous
          # rate by the current unit's stride for use in subsequent layers.
          layer_stride = 1
          layer_rate = rate
          rate *= conv_def.stride
        else:
          layer_stride = conv_def.stride
          layer_rate = 1
          current_stride *= conv_def.stride

        if isinstance(conv_def, Conv):
          end_point = end_point_base
          net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
                            stride=conv_def.stride,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params={'decay': bn_decay},
                            scope=end_point)
          end_points[end_point] = net
          if end_point == final_endpoint:
            return net, end_points

        elif isinstance(conv_def, DepthSepConv):
          end_point = end_point_base + '_depthwise'

          # By passing filters=None separable_conv2d produces only a
          # depthwise convolution layer.
          net = slim.separable_conv2d(net, None, conv_def.kernel,
                                      depth_multiplier=1,
                                      stride=layer_stride,
                                      rate=layer_rate,
                                      normalizer_fn=slim.batch_norm,
                                      normalizer_params={'decay': bn_decay},
                                      scope=end_point)

          end_points[end_point] = net
          if end_point == final_endpoint:
            return net, end_points

          end_point = end_point_base + '_pointwise'

          net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
                            stride=1,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params={'decay': bn_decay},
                            scope=end_point)

          end_points[end_point] = net
          if end_point == final_endpoint:
            return net, end_points
        else:
          # Fix: the Conv/DepthSepConv namedtuples have no 'ltype' field, so
          # the original `conv_def.ltype` would raise AttributeError instead
          # of the intended ValueError. Use the tuple's type name instead.
          raise ValueError('Unknown convolution type %s for layer %d'
                           % (type(conv_def).__name__, i))
  raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
||||
|
||||
|
||||
def nazorunet(inputs,
              num_classes,
              conv_defs,
              dropout_keep_prob=0.999,
              is_training=True,
              min_depth=8,
              depth_multiplier=1.0,
              prediction_fn=tf.contrib.layers.softmax,
              spatial_squeeze=True,
              reuse=None,
              scope='NazoruNet',
              global_pool=False):
  """Customized MobileNet model for Nazoru Input.

  Builds the MobileNet feature extractor via mobilenet_v1_base, then appends
  pooling, dropout, and a 1x1 convolution that produces the logits.

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    num_classes: number of predicted classes. If 0 or None, the logits layer
      is omitted and the input features to the logits layer (before dropout)
      are returned instead.
    conv_defs: A list of ConvDef namedtuples specifying the net architecture.
    dropout_keep_prob: the percentage of activation values that are retained.
    is_training: whether is training or not.
    min_depth: Minimum depth value (number of channels) for all convolution
      ops. Enforced when depth_multiplier < 1, and not an active constraint
      when depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape is [B, C], if false logits
      is of shape [B, 1, 1, C], where B is batch_size and C is number of
      classes.
    reuse: whether or not the network and its variables should be reused. To
      be able to reuse 'scope' must be given.
    scope: Optional variable_scope.
    global_pool: Optional boolean flag to control the avgpooling before the
      logits layer. If false or unset, pooling is done with a fixed window
      that reduces default-sized inputs to 1x1, while larger inputs lead to
      larger outputs. If true, any input size is pooled down to 1x1.

  Returns:
    net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
      is a non-zero integer, or the non-dropped-out input to the logits layer
      if num_classes is 0 or None.
    end_points: a dictionary from components of the network to the
      corresponding activation.

  Raises:
    ValueError: Input rank is invalid.
  """
  input_shape = inputs.get_shape().as_list()
  if len(input_shape) != 4:
    raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
                     len(input_shape))

  with tf.variable_scope(scope, 'NazoruNet', [inputs], reuse=reuse) as scope:
    # batch_norm and dropout behave differently at train vs. inference time;
    # wire the flag through arg_scope so every layer below picks it up.
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
      # Run the base network through its last pointwise layer.
      final_endpoint = 'Conv2d_%i_pointwise' % (len(conv_defs) - 1)
      net, end_points = mobilenet_v1_base(inputs, scope=scope,
                                          min_depth=min_depth,
                                          depth_multiplier=depth_multiplier,
                                          conv_defs=conv_defs,
                                          final_endpoint=final_endpoint)
      with tf.variable_scope('Logits'):
        if global_pool:
          # Global average pooling.
          # NOTE(review): keep_dims is the pre-TF-1.5 spelling of keepdims;
          # fine for the TF version this file targets (tf.contrib era).
          net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
          end_points['global_pool'] = net
        else:
          # Pooling with a fixed kernel size, shrunk if the feature map is
          # smaller than 7x7.
          kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
          net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                                scope='AvgPool_1a')
          end_points['AvgPool_1a'] = net
        if not num_classes:
          # Caller wants features, not class scores.
          return net, end_points
        # 1 x 1 x 1024
        net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
        # 1x1 conv acts as the fully-connected classification layer; no
        # activation/normalization so the output is raw logits.
        logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='Conv2d_1c_1x1')
        if spatial_squeeze:
          # Drop the 1x1 spatial dims: [B, 1, 1, C] -> [B, C].
          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
      end_points['Logits'] = logits
      if prediction_fn:
        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points
|
||||
|
||||
|
||||
def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
|
||||
"""Define kernel size which is automatically reduced for small input.
|
||||
If the shape of the input images is unknown at graph construction time this
|
||||
function assumes that the input images are large enough.
|
||||
Args:
|
||||
input_tensor: input tensor of size [batch_size, height, width, channels].
|
||||
kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
|
||||
Returns:
|
||||
a tensor with the kernel size.
|
||||
"""
|
||||
shape = input_tensor.get_shape().as_list()
|
||||
if shape[1] is None or shape[2] is None:
|
||||
kernel_size_out = kernel_size
|
||||
else:
|
||||
kernel_size_out = [min(shape[1], kernel_size[0]),
|
||||
min(shape[2], kernel_size[1])]
|
||||
return kernel_size_out
|
||||
64
mozc-nazoru/src/nazoru/predictor.py
Normal file
64
mozc-nazoru/src/nazoru/predictor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import tensorflow as tf
|
||||
from . import lib
|
||||
from . import utils
|
||||
import numpy as np
|
||||
|
||||
def _load_graph(model_file):
  """Load a frozen GraphDef from |model_file| into a fresh tf.Graph."""
  graph_def = tf.GraphDef()
  with open(model_file, "rb") as serialized:
    graph_def.ParseFromString(serialized.read())
  graph = tf.Graph()
  # Import under this graph so the loaded ops don't pollute the default one.
  with graph.as_default():
    tf.import_graph_def(graph_def)
  return graph
|
||||
|
||||
class NazoruPredictor():
  """Recognizes drawn characters with a frozen NazoruNet graph."""

  def __init__(self, model_file):
    self._graph = _load_graph(model_file)
    # Imported nodes are prefixed with 'import/' by tf.import_graph_def.
    self._input_operation = self._graph.get_operation_by_name(
        'import/' + lib.INPUT_NODE_NAME)
    self._output_operation = self._graph.get_operation_by_name(
        'import/' + lib.OUTPUT_NODE_NAME)

  def _predict(self, data):
    # Render the keydown trace into an image batch of size 1, then run it
    # through the frozen graph. Both phases are timed for diagnostics.
    with utils.Measure('inputs'):
      image = lib.keydowns2image(data, True, True, 16, 2)
      batch = np.expand_dims(image, axis=0)
    with utils.Measure('sess.run'):
      with tf.Session(graph=self._graph) as sess:
        feed = {self._input_operation.outputs[0]: batch}
        outputs = sess.run(self._output_operation.outputs[0], feed)
    # Single-element batch: return the per-class scores for that one sample.
    return outputs[0]

  def predict_top_n(self, data, n):
    """Predict the character drawn by |data|.

    Args:
      data: [(key, time)] |time| is elapsed time since the first character in
        ms.
      n: integer of the number of the return value.
    Returns:
      ans: [(kana, key, probability)] sorted by the probability.
    """
    scores = self._predict(data)
    best = scores.argsort()[::-1][:n]
    return [(lib.KANAS[i], lib.KEYS[i], scores[i]) for i in best]
|
||||
28
mozc-nazoru/src/nazoru/utils.py
Normal file
28
mozc-nazoru/src/nazoru/utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
class Measure():
  """Context manager that prints the wall-clock time spent in its block.

  Usage:
    with Measure('tag'):
      do_work()
  prints '[tag] <elapsed> ms' when the block exits.
  """

  def __init__(self, tag):
    # |tag| labels the printed measurement.
    self._tag = tag

  def __enter__(self):
    self._start = time.time()
    # Fix: return self (original returned None) so callers may write
    # `with Measure(tag) as m:`; callers that ignore the value are
    # unaffected.
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Renamed `type` -> `exc_type` to avoid shadowing the builtin.
    # Returning None (falsy) lets any exception from the block propagate;
    # the elapsed time is still reported in that case.
    elapsed_ms = (time.time() - self._start) * 1E3
    print('[{0}] {1} ms'.format(self._tag, elapsed_ms))
|
||||
Reference in New Issue
Block a user