diff --git a/speech_model.py b/speech_model.py index e51628d..f7c1cd6 100644 --- a/speech_model.py +++ b/speech_model.py @@ -29,6 +29,7 @@ import numpy as np from utils.ops import get_edit_distance, read_wav_data from utils.config import load_config_file, DEFAULT_CONFIG_FILENAME, load_pinyin_dict +from utils.thread import threadsafe_generator class ModelSpeech: ''' @@ -45,6 +46,7 @@ class ModelSpeech: self.speech_features = speech_features self.max_label_length = max_label_length + @threadsafe_generator def _data_generator(self, batch_size, data_loader): ''' 数据生成器函数,用于Keras的generator_fit训练 diff --git a/utils/thread.py b/utils/thread.py new file mode 100644 index 0000000..dcc1033 --- /dev/null +++ b/utils/thread.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2016-2099 Ailemon.net +# +# This file is part of ASRT Speech Recognition Tool. +# +# ASRT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# ASRT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with ASRT. If not, see . +# ============================================================================ + +''' +@author: nl8590687 +解决python生成器多线程的线程安全问题 +''' + +import threading +''' + A generic iterator and generator that takes any iterator and wrap it to make it thread safe. + This method was introducted by Anand Chitipothu in http://anandology.com/blog/using-iterators-and-generators/ + but was not compatible with python 3. This modified version is now compatible and works both in python 2.8 and 3.0 +''' +class threadsafe_iter: + """Takes an iterator/generator and makes it thread-safe by + serializing call to the `next` method of given iterator/generator. + """ + def __init__(self, it): + self.it = it + self.lock = threading.Lock() + + def __iter__(self): + return self + + def __next__(self): + with self.lock: + return self.it.__next__() + +def threadsafe_generator(f): + """A decorator that takes a generator function and makes it thread-safe. + """ + def g(*a, **kw): + return threadsafe_iter(f(*a, **kw)) + return g