mirror of
https://github.com/google/mozc-devices.git
synced 2025-11-08 16:53:28 +03:00
65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
|
|
import tensorflow as tf
|
|
from . import lib
|
|
from . import utils
|
|
import numpy as np
|
|
|
|
def _load_graph(model_file):
|
|
graph = tf.Graph()
|
|
graph_def = tf.GraphDef()
|
|
with open(model_file, "rb") as f:
|
|
graph_def.ParseFromString(f.read())
|
|
with graph.as_default():
|
|
tf.import_graph_def(graph_def)
|
|
return graph
|
|
|
|
class NazoruPredictor():
|
|
def __init__(self, model_file):
|
|
graph = _load_graph(model_file)
|
|
self._graph = graph
|
|
self._input_operation = graph.get_operation_by_name(
|
|
'import/' + lib.INPUT_NODE_NAME)
|
|
self._output_operation = graph.get_operation_by_name(
|
|
'import/' + lib.OUTPUT_NODE_NAME)
|
|
|
|
def _predict(self, data):
|
|
with utils.Measure('inputs'):
|
|
inputs = lib.keydowns2image(data, True, True, 16, 2)
|
|
inputs = np.expand_dims(inputs, axis=0)
|
|
with utils.Measure('sess.run'):
|
|
with tf.Session(graph=self._graph) as sess:
|
|
result = sess.run(self._output_operation.outputs[0],
|
|
{self._input_operation.outputs[0]: inputs})[0]
|
|
return result
|
|
def predict_top_n(self, data, n):
|
|
"""Predict the charactor drawn by |data|.
|
|
|
|
Args:
|
|
data: [(key, time)] |time| is elapsed time since the first character in ms.
|
|
n: integer of the number of the return value.
|
|
Returns:
|
|
ans: [(kana, key, probability)] sorted by the probability.
|
|
"""
|
|
result = self._predict(data)
|
|
ans = []
|
|
for i in result.argsort()[::-1][:n]:
|
|
ans.append((lib.KANAS[i], lib.KEYS[i], result[i]))
|
|
return ans
|