#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT.  If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

'''
Multi-GPU wrapper for Keras models.

Thanks to the original author for their generous contribution.
Source: https://www.jianshu.com/p/db0ba022936f
'''

import tensorflow as tf
import tensorflow.keras as kr
import tensorflow.keras.backend as K
import tensorflow.keras.layers as KL


class ParallelModel(kr.models.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.

    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.

        keras_model: The Keras model to parallelize.
        gpu_count: Number of GPUs to replicate the model on. Must be >= 1;
            a value of 1 builds a single replica (no parallelism).

        Raises ValueError if gpu_count < 1.
        """
        if gpu_count < 1:
            raise ValueError('gpu_count must be >= 1, got %r' % (gpu_count,))
        # A first, argument-less initialisation runs before any attribute is
        # assigned; presumably required by Keras so that the attributes below
        # can be set before the real initialisation (TODO confirm).
        # Thanks to @greatken999 for fixing bugs
        super(ParallelModel, self).__init__()
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
                                            outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect public loading and saving methods to the inner model.

        That's where the weights are stored. Only public names (not starting
        with an underscore) are redirected; private attributes that merely
        contain 'load'/'save' in their name, which Keras internals may use,
        stay on the wrapper itself.
        """
        if not attrname.startswith('_') and ('load' in attrname
                                            or 'save' in attrname):
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both, the wrapper
        and inner models."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    @staticmethod
    def _slice_batch(tensor, index, parts):
        """Return piece number `index` of `tensor` split into `parts` equal
        pieces along axis 0.

        tf.split with an integer count requires the size of axis 0 (the
        batch size at run time) to be divisible by `parts`.
        """
        # Slice on the CPU to avoid sending a copy of the full inputs to
        # every GPU. Saves on bandwidth and memory.
        with tf.device('/cpu:0'):
            return tf.split(tensor, parts)[index]

    @staticmethod
    def _add_batch_dim(tensor):
        """Reshape a tensor with no dimensions to shape [1, 1] so it can be
        concatenated along axis 0; return other tensors unchanged."""
        if K.int_shape(tensor) == ():
            return KL.Lambda(lambda t: K.reshape(t, [1, 1]))(tensor)
        return tensor

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.

        Every replica receives its own slice of each input; the replicas'
        outputs are concatenated back along axis 0 on the CPU.

        Returns:
            A list with one merged tensor per inner-model output, in the
            order of ``self.inner_model.output_names`` and named after them.
        """
        output_names = self.inner_model.output_names
        # outputs_all[k] collects output number k from every replica.
        outputs_all = [[] for _ in self.inner_model.outputs]

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Slice the layer's own input inside the Lambda and pass
                    # the replica index through `arguments` instead of a
                    # closure: a closure over the loop variable `i` is
                    # late-bound, so any later re-invocation of the layer
                    # function would pick the last slice for every tower.
                    inputs = [
                        KL.Lambda(self._slice_batch,
                                  arguments={'index': i,
                                             'parts': self.gpu_count},
                                  output_shape=lambda s: (None,) + s[1:])(tensor)
                        for tensor in self.inner_model.inputs]
                    # Create the model replica and get the outputs
                    outputs = self.inner_model(inputs)
                    if isinstance(outputs, (list, tuple)):
                        outputs = list(outputs)
                    else:
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for out_idx, out in enumerate(outputs):
                        outputs_all[out_idx].append(out)

        # Merge outputs on CPU
        with tf.device('/cpu:0'):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                # If outputs are numbers without dimensions, add a batch dim
                # so that Concatenate(axis=0) can join them.
                outputs = [self._add_batch_dim(o) for o in outputs]
                # Concatenate
                merged.append(KL.Concatenate(axis=0, name=name)(outputs))
        return merged