Files
mozc-devices/mozc-nazoru/src/nazoru/predictor.py
2018-03-31 23:12:46 +09:00

65 lines
2.1 KiB
Python

# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import tensorflow as tf
from . import lib
from . import utils
import numpy as np
def _load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
class NazoruPredictor():
def __init__(self, model_file):
graph = _load_graph(model_file)
self._graph = graph
self._input_operation = graph.get_operation_by_name(
'import/' + lib.INPUT_NODE_NAME)
self._output_operation = graph.get_operation_by_name(
'import/' + lib.OUTPUT_NODE_NAME)
def _predict(self, data):
with utils.Measure('inputs'):
inputs = lib.keydowns2image(data, True, True, 16, 2)
inputs = np.expand_dims(inputs, axis=0)
with utils.Measure('sess.run'):
with tf.Session(graph=self._graph) as sess:
result = sess.run(self._output_operation.outputs[0],
{self._input_operation.outputs[0]: inputs})[0]
return result
def predict_top_n(self, data, n):
"""Predict the charactor drawn by |data|.
Args:
data: [(key, time)] |time| is elapsed time since the first character in ms.
n: integer of the number of the return value.
Returns:
ans: [(kana, key, probability)] sorted by the probability.
"""
result = self._predict(data)
ans = []
for i in result.argsort()[::-1][:n]:
ans.append((lib.KANAS[i], lib.KEYS[i], result[i]))
return ans