diff --git a/test_mspeech.py b/test_mspeech.py index 563741b..1cc4358 100644 --- a/test_mspeech.py +++ b/test_mspeech.py @@ -17,10 +17,12 @@ from SpeechModel251 import ModelSpeech os.environ["CUDA_VISIBLE_DEVICES"] = "0" #进行配置,使用90%的GPU -config = tf.ConfigProto() +config = tf.compat.v1.ConfigProto() config.gpu_options.per_process_gpu_memory_fraction = 0.9 #config.gpu_options.allow_growth=True #不全部占满显存, 按需分配 -set_session(tf.Session(config=config)) +sess = tf.compat.v1.Session(config=config) +tf.compat.v1.keras.backend.set_session(sess) + datapath = '' diff --git a/train_mspeech.py b/train_mspeech.py index 4748f2e..2d99c6d 100644 --- a/train_mspeech.py +++ b/train_mspeech.py @@ -16,10 +16,11 @@ from SpeechModel251 import ModelSpeech os.environ["CUDA_VISIBLE_DEVICES"] = "0" #进行配置,使用95%的GPU -config = tf.ConfigProto() +config = tf.compat.v1.ConfigProto() config.gpu_options.per_process_gpu_memory_fraction = 0.95 #config.gpu_options.allow_growth=True #不全部占满显存, 按需分配 -set_session(tf.Session(config=config)) +sess = tf.compat.v1.Session(config=config) +tf.compat.v1.keras.backend.set_session(sess) datapath = ''