#!/usr/bin/env python3
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright 2016-2099 Ailemon.net
|
|
#
|
|
# This file is part of ASRT Speech Recognition Tool.
|
|
#
|
|
# ASRT is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
# ASRT is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
|
|
# ============================================================================
|
|
|
|
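'''
Multi-GPU data-parallel wrapper for Keras models.

感谢原作者的无私奉献
(Thanks to the original author for their selfless contribution.)

来自 (Source):
https://www.jianshu.com/p/db0ba022936f
'''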
|
|
|
|
import tensorflow as tf
|
|
import tensorflow.keras as kr
|
|
import tensorflow.keras.backend as K
|
|
import tensorflow.keras.layers as KL
|
|
|
|
class ParallelModel(kr.models.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.

    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.

    Replica ``i`` is built under ``/gpu:i`` and the per-replica outputs are
    concatenated along axis 0 under ``/cpu:0``. Every input is split along
    axis 0 with ``tf.split`` into ``gpu_count`` pieces, which requires the
    batch size fed to the wrapper to be divisible by ``gpu_count``.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.

        Args:
            keras_model: The Keras model to parallelize. Its ``inputs``
                become the wrapper's inputs and its ``output_names`` are
                used to name the wrapper's merged outputs.
            gpu_count: Number of GPUs, i.e. number of model replicas.
                Intended to be > 1 (not enforced here).
        """
        # An argument-less base init runs before any attribute is assigned
        # on ``self``; presumably Keras needs its internal state in place
        # before attribute assignment (TODO confirm).
        # Thanks to @greatken999 for fixing bugs.
        super().__init__()
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        # Initialise again, this time as a functional model that maps the
        # inner model's inputs to the merged replica outputs.
        super().__init__(inputs=self.inner_model.inputs,
                         outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model.

        That's where the weights are stored, so public names containing
        "load" or "save" (e.g. ``save_weights``, ``load_weights``) are
        looked up on ``self.inner_model``.

        Private names (leading underscore) are never redirected. tf.keras
        keeps private bookkeeping attributes whose names contain "save"
        (such as ``_saved_model_inputs_spec``). Those belong to the wrapper
        itself and must not be read from the inner model.
        """
        if (not attrname.startswith('_')
                and ('load' in attrname or 'save' in attrname)):
            return getattr(self.inner_model, attrname)
        return super().__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Print the summary of the wrapper, then of the inner model.

        All arguments are passed unchanged to both ``summary`` calls.
        """
        super().summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates one replica of the inner model per GPU and merges them.

        Returns:
            A list with one tensor per inner-model output. Each tensor is
            the concatenation (axis 0) of that output from every replica,
            and its layer is named after the corresponding inner-model
            output. Scalar outputs are first reshaped to ``[1, 1]`` so that
            they can be concatenated.
        """
        # Split every input along axis 0 into one slice per replica.
        # NOTE(review): no device scope is set here, so TensorFlow decides
        # where the split ops are placed (not necessarily the CPU).
        input_slices = {name: tf.split(x, self.gpu_count)
                        for name, x in zip(self.inner_model.input_names,
                                           self.inner_model.inputs)}

        output_names = self.inner_model.output_names
        # outputs_all[k] collects output k of every replica.
        outputs_all = [[] for _ in self.inner_model.outputs]

        # Run the model call() on each GPU to place the ops there.
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Run a slice of inputs through this replica.
                    # ``name`` and ``i`` are bound as default arguments. A
                    # plain closure would late-bind them, so any later
                    # re-execution of the Lambda would fetch the last
                    # input's last slice for every tower.
                    inputs = [
                        KL.Lambda(
                            lambda s, name=name, i=i: input_slices[name][i],
                            output_shape=lambda s: (None,) + s[1:])(tensor)
                        for name, tensor in zip(self.inner_model.input_names,
                                                self.inner_model.inputs)]
                    # Create the model replica and get the outputs.
                    outputs = self.inner_model(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later.
                    for k, out in enumerate(outputs):
                        outputs_all[k].append(out)

        def add_dim(tensor):
            """Add a dimension to tensors that don't have any."""
            # Scalars cannot be concatenated along axis 0; reshape them to
            # [1, 1] first.
            if K.int_shape(tensor) == ():
                return KL.Lambda(lambda t: K.reshape(t, [1, 1]))(tensor)
            return tensor

        # Merge outputs on CPU.
        with tf.device('/cpu:0'):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                outputs = [add_dim(t) for t in outputs]
                # Concatenate the replicas' outputs along the batch axis.
                merged.append(KL.Concatenate(axis=0, name=name)(outputs))
            return merged
|
|
|