diff --git a/general_function/muti_gpu.py b/general_function/muti_gpu.py index 621d1f2..6b13577 100644 --- a/general_function/muti_gpu.py +++ b/general_function/muti_gpu.py @@ -24,6 +24,7 @@ class ParallelModel(keras.models.Model): keras_model: The Keras model to parallelize gpu_count: Number of GPUs. Must be > 1 """ + super(ParallelModel, self).__init__() # Thanks to @greatken999 for fixing bugs self.inner_model = keras_model self.gpu_count = gpu_count merged_outputs = self.make_parallel()