diff --git a/asrserver_grpc.py b/asrserver_grpc.py new file mode 100644 index 0000000..c92751c --- /dev/null +++ b/asrserver_grpc.py @@ -0,0 +1,182 @@ + +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2016-2099 Ailemon.net +# +# This file is part of ASRT Speech Recognition Tool. +# +# ASRT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# ASRT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with ASRT. If not, see . +# ============================================================================ + +""" +@author: nl8590687 +ASRT语音识别基于gRPC协议的API服务器程序 +""" + +import argparse +import time +from concurrent import futures +import grpc + +from assets.asrt_pb2_grpc import AsrtGrpcServiceServicer, add_AsrtGrpcServiceServicer_to_server +from assets.asrt_pb2 import SpeechResponse, TextResponse +from speech_model import ModelSpeech +from speech_model_zoo import SpeechModel251BN +from speech_features import Spectrogram +from language_model3 import ModelLanguage +from utils.ops import decode_wav_bytes + +API_STATUS_CODE_OK = 200000 # OK +API_STATUS_CODE_OK_PART = 206000 # 部分结果OK,用于stream +API_STATUS_CODE_CLIENT_ERROR = 400000 +API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001 # 请求数据格式错误 +API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002 # 请求数据配置不支持 +API_STATUS_CODE_SERVER_ERROR = 500000 +API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001 # 服务器运行中出错 + +parser = argparse.ArgumentParser(description='ASRT gRPC Protocol API Service') +parser.add_argument('--listen', default='0.0.0.0', type=str, help='the network to listen') +parser.add_argument('--port', default='20002', type=str, help='the port to listen') +args = parser.parse_args() + +AUDIO_LENGTH = 1600 +AUDIO_FEATURE_LENGTH = 200 +CHANNELS = 1 +# 默认输出的拼音的表示大小是1428,即1427个拼音+1个空白块 +OUTPUT_SIZE = 1428 +sm251bn = SpeechModel251BN( + input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS), + output_size=OUTPUT_SIZE + ) +feat = Spectrogram() +ms = ModelSpeech(sm251bn, feat, max_label_length=64) +ms.load_model('save_models/' + sm251bn.get_model_name() + '.model.h5') + +ml = ModelLanguage('model_language') +ml.load_model() + + + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + + +class ApiService(AsrtGrpcServiceServicer): + ''' + 继承AsrtGrpcServiceServicer,实现hello方法 + ''' + def __init__(self): + pass + + def Speech(self, request, context): + ''' + 具体实现Speech的方法, 并按照pb的返回对象构造SpeechResponse返回 + :param request: + :param context: + :return: + ''' + wav_data = request.wav_data + wav_samples = decode_wav_bytes(samples_data=wav_data.samples, + channels=wav_data.channels, byte_width=wav_data.byte_width) + result = ms.recognize_speech(wav_samples, wav_data.sample_rate) + print("语音识别声学模型结果:", result) + return SpeechResponse(status_code=API_STATUS_CODE_OK, status_message='', + result_data=result) + + def Language(self, request, context): + ''' + 具体实现Language的方法, 并按照pb的返回对象构造TextResponse返回 + :param request: + :param context: + :return: + ''' + print('Language收到了请求:', request) + result = ml.pinyin_to_text(list(request.pinyins)) + print('Language结果:', result) + return TextResponse(status_code=API_STATUS_CODE_OK, status_message='', + text_result=result) + + def All(self, request, context): + ''' + 具体实现All的方法, 并按照pb的返回对象构造TextResponse返回 + :param request: + :param context: + :return: + ''' + wav_data = request.wav_data + wav_samples = decode_wav_bytes(samples_data=wav_data.samples, + channels=wav_data.channels, byte_width=wav_data.byte_width) + result_speech = ms.recognize_speech(wav_samples, wav_data.sample_rate) + result = ml.pinyin_to_text(result_speech) + print("语音识别结果:", result) + return TextResponse(status_code=API_STATUS_CODE_OK, status_message='', + text_result=result) + + def Stream(self, request_iterator, context): + ''' + 具体实现Stream的方法, 并按照pb的返回对象构造TextResponse返回 + :param request: + :param context: + :return: + ''' + result = list() + tmp_result_last = list() + beam_size = 100 + + for request in request_iterator: + wav_data = request.wav_data + wav_samples = decode_wav_bytes(samples_data=wav_data.samples, + channels=wav_data.channels, + byte_width=wav_data.byte_width) + result_speech = ms.recognize_speech(wav_samples, wav_data.sample_rate) + + for item_pinyin in result_speech: + tmp_result = ml.pinyin_stream_decode(tmp_result_last, item_pinyin, beam_size) + if len(tmp_result) == 0 and len(tmp_result_last) > 0: + result.append(tmp_result_last[0][0]) + print("流式语音识别结果:", ''.join(result)) + yield TextResponse(status_code=API_STATUS_CODE_OK, status_message='', + text_result=''.join(result)) + result = list() + + tmp_result = ml.pinyin_stream_decode([], item_pinyin, beam_size) + tmp_result_last = tmp_result + yield TextResponse(status_code=API_STATUS_CODE_OK_PART, status_message='', + text_result=''.join(tmp_result[0][0])) + + if len(tmp_result_last) > 0: + result.append(tmp_result_last[0][0]) + print("流式语音识别结果:", ''.join(result)) + yield TextResponse(status_code=API_STATUS_CODE_OK, status_message='', + text_result=''.join(result)) + + +def run(host, port): + ''' + gRPC API服务启动 + :return: + ''' + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + add_AsrtGrpcServiceServicer_to_server(ApiService(),server) + server.add_insecure_port(''.join([host, ':', port])) + server.start() + print("start service...") + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == '__main__': + run(host=args.listen, port=args.port) diff --git a/assets/asrt.proto b/assets/asrt.proto new file mode 100644 index 0000000..f328fe8 --- /dev/null +++ b/assets/asrt.proto @@ -0,0 +1,55 @@ +/* Copyright 2016-2099 Ailemon.net + +This file is part of ASRT Speech Recognition Tool. + +ASRT is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. +ASRT is distributed in the hope that it will be useful, + +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with ASRT. If not, see . +============================================================================ */ + +syntax = "proto3"; +package asrt; + +//定义服务接口 +service AsrtGrpcService { + rpc Speech (SpeechRequest) returns (SpeechResponse) {} //一个服务中可以定义多个接口,也就是多个函数功能 + rpc Language (LanguageRequest) returns (TextResponse) {} + rpc All (SpeechRequest) returns (TextResponse) {} + rpc Stream (stream SpeechRequest) returns (stream TextResponse) {} +} + +message SpeechRequest { + WavData wav_data = 1; +} + +message SpeechResponse { + int32 status_code = 1; + string status_message = 2; + repeated string result_data = 3; // 拼音结果 +} + +message LanguageRequest { + repeated string pinyins = 1; +} + +message TextResponse { + int32 status_code = 1; + string status_message = 2; + string text_result = 3; +} + +message WavData{ + bytes samples = 1; // wav样本点字节 + int32 sample_rate = 2; // wav采样率 + int32 channels = 3; // wav通道数 + int32 byte_width = 4; // wav样本字节宽度 +} diff --git a/assets/asrt_pb2.py b/assets/asrt_pb2.py new file mode 100644 index 0000000..9d52cac --- /dev/null +++ b/assets/asrt_pb2.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: asrt.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='asrt.proto', + package='asrt', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\nasrt.proto\x12\x04\x61srt\"0\n\rSpeechRequest\x12\x1f\n\x08wav_data\x18\x01 \x01(\x0b\x32\r.asrt.WavData\"R\n\x0eSpeechResponse\x12\x13\n\x0bstatus_code\x18\x01 \x01(\x05\x12\x16\n\x0estatus_message\x18\x02 \x01(\t\x12\x13\n\x0bresult_data\x18\x03 \x03(\t\"\"\n\x0fLanguageRequest\x12\x0f\n\x07pinyins\x18\x01 \x03(\t\"P\n\x0cTextResponse\x12\x13\n\x0bstatus_code\x18\x01 \x01(\x05\x12\x16\n\x0estatus_message\x18\x02 \x01(\t\x12\x13\n\x0btext_result\x18\x03 \x01(\t\"U\n\x07WavData\x12\x0f\n\x07samples\x18\x01 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x02 \x01(\x05\x12\x10\n\x08\x63hannels\x18\x03 \x01(\x05\x12\x12\n\nbyte_width\x18\x04 \x01(\x05\x32\xec\x01\n\x0f\x41srtGrpcService\x12\x35\n\x06Speech\x12\x13.asrt.SpeechRequest\x1a\x14.asrt.SpeechResponse\"\x00\x12\x37\n\x08Language\x12\x15.asrt.LanguageRequest\x1a\x12.asrt.TextResponse\"\x00\x12\x30\n\x03\x41ll\x12\x13.asrt.SpeechRequest\x1a\x12.asrt.TextResponse\"\x00\x12\x37\n\x06Stream\x12\x13.asrt.SpeechRequest\x1a\x12.asrt.TextResponse\"\x00(\x01\x30\x01\x62\x06proto3' +) + + + + +_SPEECHREQUEST = _descriptor.Descriptor( + name='SpeechRequest', + full_name='asrt.SpeechRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='wav_data', full_name='asrt.SpeechRequest.wav_data', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=20, + serialized_end=68, +) + + +_SPEECHRESPONSE = _descriptor.Descriptor( + name='SpeechResponse', + full_name='asrt.SpeechResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='status_code', full_name='asrt.SpeechResponse.status_code', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='status_message', full_name='asrt.SpeechResponse.status_message', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='result_data', full_name='asrt.SpeechResponse.result_data', index=2, + number=3, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=70, + serialized_end=152, +) + + +_LANGUAGEREQUEST = _descriptor.Descriptor( + name='LanguageRequest', + full_name='asrt.LanguageRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='pinyins', full_name='asrt.LanguageRequest.pinyins', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=154, + serialized_end=188, +) + + +_TEXTRESPONSE = _descriptor.Descriptor( + name='TextResponse', + full_name='asrt.TextResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='status_code', full_name='asrt.TextResponse.status_code', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='status_message', full_name='asrt.TextResponse.status_message', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='text_result', full_name='asrt.TextResponse.text_result', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=190, + serialized_end=270, +) + + +_WAVDATA = _descriptor.Descriptor( + name='WavData', + full_name='asrt.WavData', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='samples', full_name='asrt.WavData.samples', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='sample_rate', full_name='asrt.WavData.sample_rate', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='channels', full_name='asrt.WavData.channels', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='byte_width', full_name='asrt.WavData.byte_width', index=3, + number=4, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=272, + serialized_end=357, +) + +_SPEECHREQUEST.fields_by_name['wav_data'].message_type = _WAVDATA +DESCRIPTOR.message_types_by_name['SpeechRequest'] = _SPEECHREQUEST +DESCRIPTOR.message_types_by_name['SpeechResponse'] = _SPEECHRESPONSE +DESCRIPTOR.message_types_by_name['LanguageRequest'] = _LANGUAGEREQUEST +DESCRIPTOR.message_types_by_name['TextResponse'] = _TEXTRESPONSE +DESCRIPTOR.message_types_by_name['WavData'] = _WAVDATA +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +SpeechRequest = _reflection.GeneratedProtocolMessageType('SpeechRequest', (_message.Message,), { + 'DESCRIPTOR' : _SPEECHREQUEST, + '__module__' : 'asrt_pb2' + # @@protoc_insertion_point(class_scope:asrt.SpeechRequest) + }) +_sym_db.RegisterMessage(SpeechRequest) + +SpeechResponse = _reflection.GeneratedProtocolMessageType('SpeechResponse', (_message.Message,), { + 'DESCRIPTOR' : _SPEECHRESPONSE, + '__module__' : 'asrt_pb2' + # @@protoc_insertion_point(class_scope:asrt.SpeechResponse) + }) +_sym_db.RegisterMessage(SpeechResponse) + +LanguageRequest = _reflection.GeneratedProtocolMessageType('LanguageRequest', (_message.Message,), { + 'DESCRIPTOR' : _LANGUAGEREQUEST, + '__module__' : 'asrt_pb2' + # @@protoc_insertion_point(class_scope:asrt.LanguageRequest) + }) +_sym_db.RegisterMessage(LanguageRequest) + +TextResponse = _reflection.GeneratedProtocolMessageType('TextResponse', (_message.Message,), { + 'DESCRIPTOR' : _TEXTRESPONSE, + '__module__' : 'asrt_pb2' + # @@protoc_insertion_point(class_scope:asrt.TextResponse) + }) +_sym_db.RegisterMessage(TextResponse) + +WavData = _reflection.GeneratedProtocolMessageType('WavData', (_message.Message,), { + 'DESCRIPTOR' : _WAVDATA, + '__module__' : 'asrt_pb2' + # @@protoc_insertion_point(class_scope:asrt.WavData) + }) +_sym_db.RegisterMessage(WavData) + + + +_ASRTGRPCSERVICE = _descriptor.ServiceDescriptor( + name='AsrtGrpcService', + full_name='asrt.AsrtGrpcService', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=360, + serialized_end=596, + methods=[ + _descriptor.MethodDescriptor( + name='Speech', + full_name='asrt.AsrtGrpcService.Speech', + index=0, + containing_service=None, + input_type=_SPEECHREQUEST, + output_type=_SPEECHRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='Language', + full_name='asrt.AsrtGrpcService.Language', + index=1, + containing_service=None, + input_type=_LANGUAGEREQUEST, + output_type=_TEXTRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='All', + full_name='asrt.AsrtGrpcService.All', + index=2, + containing_service=None, + input_type=_SPEECHREQUEST, + output_type=_TEXTRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='Stream', + full_name='asrt.AsrtGrpcService.Stream', + index=3, + containing_service=None, + input_type=_SPEECHREQUEST, + output_type=_TEXTRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) +_sym_db.RegisterServiceDescriptor(_ASRTGRPCSERVICE) + +DESCRIPTOR.services_by_name['AsrtGrpcService'] = _ASRTGRPCSERVICE + +# @@protoc_insertion_point(module_scope) diff --git a/assets/asrt_pb2_grpc.py b/assets/asrt_pb2_grpc.py new file mode 100644 index 0000000..85b670b --- /dev/null +++ b/assets/asrt_pb2_grpc.py @@ -0,0 +1,168 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import assets.asrt_pb2 as asrt__pb2 + + +class AsrtGrpcServiceStub(object): + """定义服务接口 + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Speech = channel.unary_unary( + '/asrt.AsrtGrpcService/Speech', + request_serializer=asrt__pb2.SpeechRequest.SerializeToString, + response_deserializer=asrt__pb2.SpeechResponse.FromString, + ) + self.Language = channel.unary_unary( + '/asrt.AsrtGrpcService/Language', + request_serializer=asrt__pb2.LanguageRequest.SerializeToString, + response_deserializer=asrt__pb2.TextResponse.FromString, + ) + self.All = channel.unary_unary( + '/asrt.AsrtGrpcService/All', + request_serializer=asrt__pb2.SpeechRequest.SerializeToString, + response_deserializer=asrt__pb2.TextResponse.FromString, + ) + self.Stream = channel.stream_stream( + '/asrt.AsrtGrpcService/Stream', + request_serializer=asrt__pb2.SpeechRequest.SerializeToString, + response_deserializer=asrt__pb2.TextResponse.FromString, + ) + + +class AsrtGrpcServiceServicer(object): + """定义服务接口 + """ + + def Speech(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Language(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def All(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Stream(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AsrtGrpcServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Speech': grpc.unary_unary_rpc_method_handler( + servicer.Speech, + request_deserializer=asrt__pb2.SpeechRequest.FromString, + response_serializer=asrt__pb2.SpeechResponse.SerializeToString, + ), + 'Language': grpc.unary_unary_rpc_method_handler( + servicer.Language, + request_deserializer=asrt__pb2.LanguageRequest.FromString, + response_serializer=asrt__pb2.TextResponse.SerializeToString, + ), + 'All': grpc.unary_unary_rpc_method_handler( + servicer.All, + request_deserializer=asrt__pb2.SpeechRequest.FromString, + response_serializer=asrt__pb2.TextResponse.SerializeToString, + ), + 'Stream': grpc.stream_stream_rpc_method_handler( + servicer.Stream, + request_deserializer=asrt__pb2.SpeechRequest.FromString, + response_serializer=asrt__pb2.TextResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'asrt.AsrtGrpcService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class AsrtGrpcService(object): + """定义服务接口 + """ + + @staticmethod + def Speech(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/Speech', + asrt__pb2.SpeechRequest.SerializeToString, + asrt__pb2.SpeechResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Language(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/Language', + asrt__pb2.LanguageRequest.SerializeToString, + asrt__pb2.TextResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def All(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/All', + asrt__pb2.SpeechRequest.SerializeToString, + asrt__pb2.TextResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Stream(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream(request_iterator, target, '/asrt.AsrtGrpcService/Stream', + asrt__pb2.SpeechRequest.SerializeToString, + asrt__pb2.TextResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/client_grpc.py b/client_grpc.py new file mode 100644 index 0000000..1d4de55 --- /dev/null +++ b/client_grpc.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2016-2099 Ailemon.net +# +# This file is part of ASRT Speech Recognition Tool. +# +# ASRT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# ASRT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with ASRT. If not, see . +# ============================================================================ + +''' +@author: nl8590687 +ASRT语音识别asrserver grpc协议测试专用客户端 +''' + +import time +import grpc + +from assets.asrt_pb2_grpc import AsrtGrpcServiceStub +from assets.asrt_pb2 import SpeechRequest, LanguageRequest, WavData + +from utils.ops import read_wav_bytes + +def run_speech(): + ''' + 请求ASRT服务Speech方法 + :return: + ''' + conn=grpc.insecure_channel('127.0.0.1:20002') + client = AsrtGrpcServiceStub(channel=conn) + + wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('assets/A11_0.wav') + print('sample_width:', sample_width) + wav_data = WavData(samples=wav_bytes, sample_rate=sample_rate, + channels=channels, byte_width=sample_width) + + request = SpeechRequest(wav_data=wav_data) + time_stamp0=time.time() + response = client.Speech(request) + time_stamp1 = time.time() + print('time:', time_stamp1-time_stamp0, 's') + print("received:",response.result_data) + +def run_lan(): + ''' + 请求ASRT服务Language方法 + :return: + ''' + conn=grpc.insecure_channel('127.0.0.1:20002') + client = AsrtGrpcServiceStub(channel=conn) + pinyin_data = ['ni3', 'hao3', 'ya5'] + request = LanguageRequest(pinyins=pinyin_data) + time_stamp0=time.time() + response = client.Language(request) + time_stamp1 = time.time() + print('time:', time_stamp1-time_stamp0, 's') + print("received:",response.text_result) + +def run_all(): + ''' + 请求ASRT服务All方法 + :return: + ''' + conn=grpc.insecure_channel('127.0.0.1:20002') + client = AsrtGrpcServiceStub(channel=conn) + + wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('assets/A11_0.wav') + print('sample_width:', sample_width) + wav_data = WavData(samples=wav_bytes, sample_rate=sample_rate, + channels=channels, byte_width=sample_width) + + request = SpeechRequest(wav_data=wav_data) + time_stamp0=time.time() + response = client.All(request) + time_stamp1=time.time() + print("received:",response.text_result) + print('time:', time_stamp1-time_stamp0, 's') + +def run_stream(): + ''' + 请求ASRT服务Stream方法 + :return: + ''' + conn=grpc.insecure_channel('127.0.0.1:20002') + client = AsrtGrpcServiceStub(channel=conn) + + wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('assets/A11_0.wav') + print('sample_width:', sample_width) + wav_data = WavData(samples=wav_bytes, sample_rate=sample_rate, + channels=channels, byte_width=sample_width) + + # 先制造一些客户端能发送的数据 + def make_some_data(): + for _ in range(1): + time.sleep(1) + yield SpeechRequest(wav_data=wav_data) + + try: + status_response = client.Stream(make_some_data()) + for ret in status_response: + print("received:", ret.text_result, " , status:", ret.status_code) + time.sleep(0.1) + except Exception as any_exception: + print(f'err in send_status:{any_exception}') + return + + +if __name__ == '__main__': + #run_all() + run_stream()