feat: 添加grpc服务API接口,以及流式识别实现
This commit is contained in:
parent
c9868ff3ac
commit
6936457f56
|
@ -0,0 +1,182 @@
|
|||
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2016-2099 Ailemon.net
|
||||
#
|
||||
# This file is part of ASRT Speech Recognition Tool.
|
||||
#
|
||||
# ASRT is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
# ASRT is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
@author: nl8590687
|
||||
ASRT语音识别基于gRPC协议的API服务器程序
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
|
||||
from assets.asrt_pb2_grpc import AsrtGrpcServiceServicer, add_AsrtGrpcServiceServicer_to_server
|
||||
from assets.asrt_pb2 import SpeechResponse, TextResponse
|
||||
from speech_model import ModelSpeech
|
||||
from speech_model_zoo import SpeechModel251BN
|
||||
from speech_features import Spectrogram
|
||||
from language_model3 import ModelLanguage
|
||||
from utils.ops import decode_wav_bytes
|
||||
|
||||
# API status codes carried in the gRPC response messages.
API_STATUS_CODE_OK = 200000  # OK
API_STATUS_CODE_OK_PART = 206000  # partial result OK, used for streaming
API_STATUS_CODE_CLIENT_ERROR = 400000
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001  # request data format error
API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002  # request configuration not supported
API_STATUS_CODE_SERVER_ERROR = 500000
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001  # error while the server is running

# Command-line options selecting the listen address of the gRPC service.
parser = argparse.ArgumentParser(description='ASRT gRPC Protocol API Service')
parser.add_argument('--listen', default='0.0.0.0', type=str, help='the network to listen')
parser.add_argument('--port', default='20002', type=str, help='the port to listen')
args = parser.parse_args()

# Acoustic-model input geometry expected by SpeechModel251BN.
AUDIO_LENGTH = 1600
AUDIO_FEATURE_LENGTH = 200
CHANNELS = 1
# Default pinyin output size is 1428: 1427 pinyin tokens + 1 blank token.
OUTPUT_SIZE = 1428
sm251bn = SpeechModel251BN(
    input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
    output_size=OUTPUT_SIZE
)
feat = Spectrogram()
ms = ModelSpeech(sm251bn, feat, max_label_length=64)
# NOTE(review): the acoustic and language models are loaded at import time, so
# importing this module has heavy side effects (model files must be present).
ms.load_model('save_models/' + sm251bn.get_model_name() + '.model.h5')

ml = ModelLanguage('model_language')
ml.load_model()

# Sleep interval for the serve-forever loop in run().
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
|
||||
class ApiService(AsrtGrpcServiceServicer):
    '''
    AsrtGrpcServiceServicer implementation for the ASRT gRPC API.

    Speech   -> acoustic model only: audio -> pinyin sequence
    Language -> language model only: pinyin -> text
    All      -> full pipeline: audio -> text
    Stream   -> bidirectional streaming full pipeline
    '''
    def __init__(self):
        pass

    def Speech(self, request, context):
        '''
        Implement the Speech rpc: run the acoustic model on the request's wav
        data and wrap the recognized pinyin sequence in a SpeechResponse.

        :param request: asrt.SpeechRequest
        :param context: grpc.ServicerContext
        :return: asrt.SpeechResponse
        '''
        wav_data = request.wav_data
        wav_samples = decode_wav_bytes(samples_data=wav_data.samples,
                                       channels=wav_data.channels, byte_width=wav_data.byte_width)
        result = ms.recognize_speech(wav_samples, wav_data.sample_rate)
        print("语音识别声学模型结果:", result)
        return SpeechResponse(status_code=API_STATUS_CODE_OK, status_message='',
                              result_data=result)

    def Language(self, request, context):
        '''
        Implement the Language rpc: convert a pinyin sequence to text with the
        language model and wrap it in a TextResponse.

        :param request: asrt.LanguageRequest
        :param context: grpc.ServicerContext
        :return: asrt.TextResponse
        '''
        print('Language收到了请求:', request)
        result = ml.pinyin_to_text(list(request.pinyins))
        print('Language结果:', result)
        return TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
                            text_result=result)

    def All(self, request, context):
        '''
        Implement the All rpc: full pipeline — acoustic model followed by the
        language model — wrapped in a TextResponse.

        :param request: asrt.SpeechRequest
        :param context: grpc.ServicerContext
        :return: asrt.TextResponse
        '''
        wav_data = request.wav_data
        wav_samples = decode_wav_bytes(samples_data=wav_data.samples,
                                       channels=wav_data.channels, byte_width=wav_data.byte_width)
        result_speech = ms.recognize_speech(wav_samples, wav_data.sample_rate)
        result = ml.pinyin_to_text(result_speech)
        print("语音识别结果:", result)
        return TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
                            text_result=result)

    def Stream(self, request_iterator, context):
        '''
        Implement the Stream rpc: consume a stream of SpeechRequest messages
        and yield TextResponse messages — partial hypotheses with status
        API_STATUS_CODE_OK_PART, and accumulated sentence results with status
        API_STATUS_CODE_OK at sentence boundaries and at end of stream.

        :param request_iterator: iterator of asrt.SpeechRequest
        :param context: grpc.ServicerContext
        :return: generator of asrt.TextResponse
        '''
        result = list()
        tmp_result_last = list()
        beam_size = 100

        for request in request_iterator:
            wav_data = request.wav_data
            wav_samples = decode_wav_bytes(samples_data=wav_data.samples,
                                           channels=wav_data.channels,
                                           byte_width=wav_data.byte_width)
            result_speech = ms.recognize_speech(wav_samples, wav_data.sample_rate)

            for item_pinyin in result_speech:
                tmp_result = ml.pinyin_stream_decode(tmp_result_last, item_pinyin, beam_size)
                if len(tmp_result) == 0 and len(tmp_result_last) > 0:
                    # The decoder hit a sentence boundary: emit the accumulated
                    # best hypothesis as a final result and start a new sentence.
                    result.append(tmp_result_last[0][0])
                    print("流式语音识别结果:", ''.join(result))
                    yield TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
                                       text_result=''.join(result))
                    result = list()

                    # Re-decode the current pinyin against an empty context.
                    tmp_result = ml.pinyin_stream_decode([], item_pinyin, beam_size)
                tmp_result_last = tmp_result
                # BUGFIX: guard against an empty beam. The original read
                # tmp_result[0][0] unconditionally, which raises IndexError
                # (aborting the whole RPC) whenever pinyin_stream_decode
                # returns no hypothesis for the current pinyin.
                if len(tmp_result) > 0:
                    yield TextResponse(status_code=API_STATUS_CODE_OK_PART, status_message='',
                                       text_result=''.join(tmp_result[0][0]))

        # End of stream: flush whatever hypothesis is still pending.
        if len(tmp_result_last) > 0:
            result.append(tmp_result_last[0][0])
        print("流式语音识别结果:", ''.join(result))
        yield TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
                           text_result=''.join(result))
|
||||
|
||||
|
||||
def run(host, port):
    '''
    Launch the gRPC API service and block until interrupted.

    :param host: network interface to bind, e.g. '0.0.0.0'
    :param port: TCP port to listen on, as a string
    :return: None
    '''
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    add_AsrtGrpcServiceServicer_to_server(ApiService(), grpc_server)
    grpc_server.add_insecure_port(f'{host}:{port}')
    grpc_server.start()
    print("start service...")
    try:
        # Serve forever; grpc.server runs in background threads, so the main
        # thread just sleeps until Ctrl-C arrives.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        grpc_server.stop(0)


if __name__ == '__main__':
    run(host=args.listen, port=args.port)
|
|
@ -0,0 +1,55 @@
|
|||
/* Copyright 2016-2099 Ailemon.net
|
||||
|
||||
This file is part of ASRT Speech Recognition Tool.
|
||||
|
||||
ASRT is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
ASRT is distributed in the hope that it will be useful,
|
||||
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with ASRT. If not, see <https://www.gnu.org/licenses/>.
|
||||
============================================================================ */
|
||||
|
||||
syntax = "proto3";
package asrt;

// Service interface of the ASRT speech-recognition API.
service AsrtGrpcService {
    // Acoustic model only: audio in, pinyin sequence out.
    rpc Speech (SpeechRequest) returns (SpeechResponse) {}
    // Language model only: pinyin sequence in, text out.
    rpc Language (LanguageRequest) returns (TextResponse) {}
    // Full pipeline: audio in, text out.
    rpc All (SpeechRequest) returns (TextResponse) {}
    // Bidirectional streaming full pipeline.
    rpc Stream (stream SpeechRequest) returns (stream TextResponse) {}
}

message SpeechRequest {
    WavData wav_data = 1;
}

message SpeechResponse {
    int32 status_code = 1;
    string status_message = 2;
    repeated string result_data = 3; // pinyin result
}

message LanguageRequest {
    repeated string pinyins = 1;
}

message TextResponse {
    int32 status_code = 1;
    string status_message = 2;
    string text_result = 3;
}

message WavData{
    bytes samples = 1; // raw wav sample bytes
    int32 sample_rate = 2; // wav sample rate
    int32 channels = 3; // number of wav channels
    int32 byte_width = 4; // byte width of each wav sample
}
|
|
@ -0,0 +1,336 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: asrt.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from google.protobuf import reflection as _reflection
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||
name='asrt.proto',
|
||||
package='asrt',
|
||||
syntax='proto3',
|
||||
serialized_options=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
serialized_pb=b'\n\nasrt.proto\x12\x04\x61srt\"0\n\rSpeechRequest\x12\x1f\n\x08wav_data\x18\x01 \x01(\x0b\x32\r.asrt.WavData\"R\n\x0eSpeechResponse\x12\x13\n\x0bstatus_code\x18\x01 \x01(\x05\x12\x16\n\x0estatus_message\x18\x02 \x01(\t\x12\x13\n\x0bresult_data\x18\x03 \x03(\t\"\"\n\x0fLanguageRequest\x12\x0f\n\x07pinyins\x18\x01 \x03(\t\"P\n\x0cTextResponse\x12\x13\n\x0bstatus_code\x18\x01 \x01(\x05\x12\x16\n\x0estatus_message\x18\x02 \x01(\t\x12\x13\n\x0btext_result\x18\x03 \x01(\t\"U\n\x07WavData\x12\x0f\n\x07samples\x18\x01 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x02 \x01(\x05\x12\x10\n\x08\x63hannels\x18\x03 \x01(\x05\x12\x12\n\nbyte_width\x18\x04 \x01(\x05\x32\xec\x01\n\x0f\x41srtGrpcService\x12\x35\n\x06Speech\x12\x13.asrt.SpeechRequest\x1a\x14.asrt.SpeechResponse\"\x00\x12\x37\n\x08Language\x12\x15.asrt.LanguageRequest\x1a\x12.asrt.TextResponse\"\x00\x12\x30\n\x03\x41ll\x12\x13.asrt.SpeechRequest\x1a\x12.asrt.TextResponse\"\x00\x12\x37\n\x06Stream\x12\x13.asrt.SpeechRequest\x1a\x12.asrt.TextResponse\"\x00(\x01\x30\x01\x62\x06proto3'
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
_SPEECHREQUEST = _descriptor.Descriptor(
|
||||
name='SpeechRequest',
|
||||
full_name='asrt.SpeechRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='wav_data', full_name='asrt.SpeechRequest.wav_data', index=0,
|
||||
number=1, type=11, cpp_type=10, label=1,
|
||||
has_default_value=False, default_value=None,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=20,
|
||||
serialized_end=68,
|
||||
)
|
||||
|
||||
|
||||
_SPEECHRESPONSE = _descriptor.Descriptor(
|
||||
name='SpeechResponse',
|
||||
full_name='asrt.SpeechResponse',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='status_code', full_name='asrt.SpeechResponse.status_code', index=0,
|
||||
number=1, type=5, cpp_type=1, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='status_message', full_name='asrt.SpeechResponse.status_message', index=1,
|
||||
number=2, type=9, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='result_data', full_name='asrt.SpeechResponse.result_data', index=2,
|
||||
number=3, type=9, cpp_type=9, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=70,
|
||||
serialized_end=152,
|
||||
)
|
||||
|
||||
|
||||
_LANGUAGEREQUEST = _descriptor.Descriptor(
|
||||
name='LanguageRequest',
|
||||
full_name='asrt.LanguageRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='pinyins', full_name='asrt.LanguageRequest.pinyins', index=0,
|
||||
number=1, type=9, cpp_type=9, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=154,
|
||||
serialized_end=188,
|
||||
)
|
||||
|
||||
|
||||
_TEXTRESPONSE = _descriptor.Descriptor(
|
||||
name='TextResponse',
|
||||
full_name='asrt.TextResponse',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='status_code', full_name='asrt.TextResponse.status_code', index=0,
|
||||
number=1, type=5, cpp_type=1, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='status_message', full_name='asrt.TextResponse.status_message', index=1,
|
||||
number=2, type=9, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='text_result', full_name='asrt.TextResponse.text_result', index=2,
|
||||
number=3, type=9, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=190,
|
||||
serialized_end=270,
|
||||
)
|
||||
|
||||
|
||||
_WAVDATA = _descriptor.Descriptor(
|
||||
name='WavData',
|
||||
full_name='asrt.WavData',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='samples', full_name='asrt.WavData.samples', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='sample_rate', full_name='asrt.WavData.sample_rate', index=1,
|
||||
number=2, type=5, cpp_type=1, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='channels', full_name='asrt.WavData.channels', index=2,
|
||||
number=3, type=5, cpp_type=1, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='byte_width', full_name='asrt.WavData.byte_width', index=3,
|
||||
number=4, type=5, cpp_type=1, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=272,
|
||||
serialized_end=357,
|
||||
)
|
||||
|
||||
_SPEECHREQUEST.fields_by_name['wav_data'].message_type = _WAVDATA
|
||||
DESCRIPTOR.message_types_by_name['SpeechRequest'] = _SPEECHREQUEST
|
||||
DESCRIPTOR.message_types_by_name['SpeechResponse'] = _SPEECHRESPONSE
|
||||
DESCRIPTOR.message_types_by_name['LanguageRequest'] = _LANGUAGEREQUEST
|
||||
DESCRIPTOR.message_types_by_name['TextResponse'] = _TEXTRESPONSE
|
||||
DESCRIPTOR.message_types_by_name['WavData'] = _WAVDATA
|
||||
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||
|
||||
SpeechRequest = _reflection.GeneratedProtocolMessageType('SpeechRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _SPEECHREQUEST,
|
||||
'__module__' : 'asrt_pb2'
|
||||
# @@protoc_insertion_point(class_scope:asrt.SpeechRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(SpeechRequest)
|
||||
|
||||
SpeechResponse = _reflection.GeneratedProtocolMessageType('SpeechResponse', (_message.Message,), {
|
||||
'DESCRIPTOR' : _SPEECHRESPONSE,
|
||||
'__module__' : 'asrt_pb2'
|
||||
# @@protoc_insertion_point(class_scope:asrt.SpeechResponse)
|
||||
})
|
||||
_sym_db.RegisterMessage(SpeechResponse)
|
||||
|
||||
LanguageRequest = _reflection.GeneratedProtocolMessageType('LanguageRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _LANGUAGEREQUEST,
|
||||
'__module__' : 'asrt_pb2'
|
||||
# @@protoc_insertion_point(class_scope:asrt.LanguageRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(LanguageRequest)
|
||||
|
||||
TextResponse = _reflection.GeneratedProtocolMessageType('TextResponse', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TEXTRESPONSE,
|
||||
'__module__' : 'asrt_pb2'
|
||||
# @@protoc_insertion_point(class_scope:asrt.TextResponse)
|
||||
})
|
||||
_sym_db.RegisterMessage(TextResponse)
|
||||
|
||||
WavData = _reflection.GeneratedProtocolMessageType('WavData', (_message.Message,), {
|
||||
'DESCRIPTOR' : _WAVDATA,
|
||||
'__module__' : 'asrt_pb2'
|
||||
# @@protoc_insertion_point(class_scope:asrt.WavData)
|
||||
})
|
||||
_sym_db.RegisterMessage(WavData)
|
||||
|
||||
|
||||
|
||||
_ASRTGRPCSERVICE = _descriptor.ServiceDescriptor(
|
||||
name='AsrtGrpcService',
|
||||
full_name='asrt.AsrtGrpcService',
|
||||
file=DESCRIPTOR,
|
||||
index=0,
|
||||
serialized_options=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
serialized_start=360,
|
||||
serialized_end=596,
|
||||
methods=[
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Speech',
|
||||
full_name='asrt.AsrtGrpcService.Speech',
|
||||
index=0,
|
||||
containing_service=None,
|
||||
input_type=_SPEECHREQUEST,
|
||||
output_type=_SPEECHRESPONSE,
|
||||
serialized_options=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Language',
|
||||
full_name='asrt.AsrtGrpcService.Language',
|
||||
index=1,
|
||||
containing_service=None,
|
||||
input_type=_LANGUAGEREQUEST,
|
||||
output_type=_TEXTRESPONSE,
|
||||
serialized_options=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='All',
|
||||
full_name='asrt.AsrtGrpcService.All',
|
||||
index=2,
|
||||
containing_service=None,
|
||||
input_type=_SPEECHREQUEST,
|
||||
output_type=_TEXTRESPONSE,
|
||||
serialized_options=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Stream',
|
||||
full_name='asrt.AsrtGrpcService.Stream',
|
||||
index=3,
|
||||
containing_service=None,
|
||||
input_type=_SPEECHREQUEST,
|
||||
output_type=_TEXTRESPONSE,
|
||||
serialized_options=None,
|
||||
create_key=_descriptor._internal_create_key,
|
||||
),
|
||||
])
|
||||
_sym_db.RegisterServiceDescriptor(_ASRTGRPCSERVICE)
|
||||
|
||||
DESCRIPTOR.services_by_name['AsrtGrpcService'] = _ASRTGRPCSERVICE
|
||||
|
||||
# @@protoc_insertion_point(module_scope)
|
|
@ -0,0 +1,168 @@
|
|||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import assets.asrt_pb2 as asrt__pb2
|
||||
|
||||
|
||||
class AsrtGrpcServiceStub(object):
|
||||
"""定义服务接口
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Speech = channel.unary_unary(
|
||||
'/asrt.AsrtGrpcService/Speech',
|
||||
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
|
||||
response_deserializer=asrt__pb2.SpeechResponse.FromString,
|
||||
)
|
||||
self.Language = channel.unary_unary(
|
||||
'/asrt.AsrtGrpcService/Language',
|
||||
request_serializer=asrt__pb2.LanguageRequest.SerializeToString,
|
||||
response_deserializer=asrt__pb2.TextResponse.FromString,
|
||||
)
|
||||
self.All = channel.unary_unary(
|
||||
'/asrt.AsrtGrpcService/All',
|
||||
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
|
||||
response_deserializer=asrt__pb2.TextResponse.FromString,
|
||||
)
|
||||
self.Stream = channel.stream_stream(
|
||||
'/asrt.AsrtGrpcService/Stream',
|
||||
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
|
||||
response_deserializer=asrt__pb2.TextResponse.FromString,
|
||||
)
|
||||
|
||||
|
||||
class AsrtGrpcServiceServicer(object):
|
||||
"""定义服务接口
|
||||
"""
|
||||
|
||||
def Speech(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Language(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def All(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Stream(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsrtGrpcServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Speech': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Speech,
|
||||
request_deserializer=asrt__pb2.SpeechRequest.FromString,
|
||||
response_serializer=asrt__pb2.SpeechResponse.SerializeToString,
|
||||
),
|
||||
'Language': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Language,
|
||||
request_deserializer=asrt__pb2.LanguageRequest.FromString,
|
||||
response_serializer=asrt__pb2.TextResponse.SerializeToString,
|
||||
),
|
||||
'All': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.All,
|
||||
request_deserializer=asrt__pb2.SpeechRequest.FromString,
|
||||
response_serializer=asrt__pb2.TextResponse.SerializeToString,
|
||||
),
|
||||
'Stream': grpc.stream_stream_rpc_method_handler(
|
||||
servicer.Stream,
|
||||
request_deserializer=asrt__pb2.SpeechRequest.FromString,
|
||||
response_serializer=asrt__pb2.TextResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'asrt.AsrtGrpcService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AsrtGrpcService(object):
|
||||
"""定义服务接口
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def Speech(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/Speech',
|
||||
asrt__pb2.SpeechRequest.SerializeToString,
|
||||
asrt__pb2.SpeechResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Language(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/Language',
|
||||
asrt__pb2.LanguageRequest.SerializeToString,
|
||||
asrt__pb2.TextResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def All(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/All',
|
||||
asrt__pb2.SpeechRequest.SerializeToString,
|
||||
asrt__pb2.TextResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Stream(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_stream(request_iterator, target, '/asrt.AsrtGrpcService/Stream',
|
||||
asrt__pb2.SpeechRequest.SerializeToString,
|
||||
asrt__pb2.TextResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
@ -0,0 +1,120 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2016-2099 Ailemon.net
|
||||
#
|
||||
# This file is part of ASRT Speech Recognition Tool.
|
||||
#
|
||||
# ASRT is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
# ASRT is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
@author: nl8590687
|
||||
ASRT语音识别asrserver grpc协议测试专用客户端
|
||||
'''
|
||||
|
||||
import time
|
||||
import grpc
|
||||
|
||||
from assets.asrt_pb2_grpc import AsrtGrpcServiceStub
|
||||
from assets.asrt_pb2 import SpeechRequest, LanguageRequest, WavData
|
||||
|
||||
from utils.ops import read_wav_bytes
|
||||
|
||||
def run_speech():
    '''
    Invoke the Speech method of the ASRT service and print the pinyin result.

    :return: None
    '''
    channel = grpc.insecure_channel('127.0.0.1:20002')
    stub = AsrtGrpcServiceStub(channel=channel)

    wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('assets/A11_0.wav')
    print('sample_width:', sample_width)
    wav_data = WavData(samples=wav_bytes, sample_rate=sample_rate,
                       channels=channels, byte_width=sample_width)
    speech_request = SpeechRequest(wav_data=wav_data)

    t_begin = time.time()
    response = stub.Speech(speech_request)
    t_end = time.time()
    print('time:', t_end - t_begin, 's')
    print("received:", response.result_data)
|
||||
|
||||
def run_lan():
    '''
    Invoke the Language method of the ASRT service and print the text result.

    :return: None
    '''
    channel = grpc.insecure_channel('127.0.0.1:20002')
    stub = AsrtGrpcServiceStub(channel=channel)

    # A small fixed pinyin sequence ("ni hao ya") as the test payload.
    pinyin_data = ['ni3', 'hao3', 'ya5']
    lan_request = LanguageRequest(pinyins=pinyin_data)

    t_begin = time.time()
    response = stub.Language(lan_request)
    t_end = time.time()
    print('time:', t_end - t_begin, 's')
    print("received:", response.text_result)
|
||||
|
||||
def run_all():
    '''
    Invoke the All method (acoustic + language model) of the ASRT service
    and print the recognized text.

    :return: None
    '''
    channel = grpc.insecure_channel('127.0.0.1:20002')
    stub = AsrtGrpcServiceStub(channel=channel)

    wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('assets/A11_0.wav')
    print('sample_width:', sample_width)
    wav_data = WavData(samples=wav_bytes, sample_rate=sample_rate,
                       channels=channels, byte_width=sample_width)

    t_begin = time.time()
    response = stub.All(SpeechRequest(wav_data=wav_data))
    t_end = time.time()
    print("received:", response.text_result)
    print('time:', t_end - t_begin, 's')
|
||||
|
||||
def run_stream():
    '''
    Invoke the streaming Stream method of the ASRT service and print each
    partial/final result as it arrives.

    :return: None
    '''
    channel = grpc.insecure_channel('127.0.0.1:20002')
    stub = AsrtGrpcServiceStub(channel=channel)

    wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('assets/A11_0.wav')
    print('sample_width:', sample_width)
    wav_data = WavData(samples=wav_bytes, sample_rate=sample_rate,
                       channels=channels, byte_width=sample_width)

    def make_some_data():
        # Produce the client-side request stream (a single request here,
        # delayed to mimic live audio arriving).
        for _ in range(1):
            time.sleep(1)
            yield SpeechRequest(wav_data=wav_data)

    try:
        for ret in stub.Stream(make_some_data()):
            print("received:", ret.text_result, " , status:", ret.status_code)
            time.sleep(0.1)
    except Exception as any_exception:
        # Best-effort test client: report the failure and bail out.
        print(f'err in send_status:{any_exception}')
        return
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Pick which client exercise to run; run_all() is kept for convenience.
    #run_all()
    run_stream()
|
Loading…
Reference in New Issue