千家信息网

keras的get_value运行越来越慢如何解决

发表于:2025-02-16 作者:千家信息网编辑
千家信息网最后更新 2025年02月16日,这篇文章主要介绍"keras的get_value运行越来越慢如何解决"的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇"keras的get_value运行越来越慢如
千家信息网最后更新 2025年02月16日keras的get_value运行越来越慢如何解决

这篇文章主要介绍"keras的get_value运行越来越慢如何解决"的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇"keras的get_value运行越来越慢如何解决"文章能帮助大家解决问题。

问题描述


如上图所示,经过时间和内存消耗跟踪测试,发现是keras.backend.get_value() 函数导致的程序越来越慢,而且严重的造成内存泄露;

查看该函数内部实现,发现一个主要核心是x.eval(session=get_session()),该语句可能是导致内存泄露和运行慢的核心语句; 根据查看一些博文得到了运行得越来越慢的

原因该x.eval函数会添加新的节点到tf的图中;而这也导致了tf的图越来越大,内存泄露;

解决方法

import tensorflow.keras.backend as Kdef get_my_session(gpu_fraction=0.1):    '''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''    num_threads = os.environ.get('OMP_NUM_THREADS')    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)    if num_threads:        return tf.Session(config=tf.ConfigProto(            gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))    else:        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))K.set_session(get_my_session())

如上图所示, 我在使用tensorflow之前(也就是该工程文件前面),对session进行自定义,然后用自定义的session设定keras.backend.set_session();

然后删除get_value() 函数,直接用get_value()中所使用的执行语句x.eval(session=get_my_session());这样这个添加节点导致内存泄露的核心语句x.eval()就使用的是该工程统一自定义session,然后用tf.reset_default_graph() 对图重置就可以了

即上图问题代码修改为:

output = ctc_decode(y_pred,input_length=input_length,)output = output[0][0]out = output.eval(session=get_my_session())# 删除 K.get_value(out[0][0])tf.reset_default_graph() # 然后重置tf图,这句很关键

这样就解决了get_value()导致的越来越慢的问题;

个人认为:这样可能就不会总是添加新的节点,导致tf图不断地无限变大;而是重复使用这一个自定义的节点。

补充:tensorflow与keras之间版本问题引起get_session问题解决办法

1.产生报错原因

import tensorflow.keras.backend as Kdef __init__(self, **kwargs):    self.__dict__.update(self._defaults) # set up default values    self.__dict__.update(kwargs) # and update with user overrides    self.class_names = self._get_class()    self.anchors = self._get_anchors()    self.sess = K.get_session()

报错如下:

get_session is not available when using TensorFlow 2.0.

意思是 tf2.0 没有 get_session

2.解决方案1

import tensorflow.python.keras.backend as Ksess = K.get_session()

3. 解决方案2

import tensorflow as tfsess = tf.compat.v1.keras.backend.get_session()

之前一直采用方案1 解决,感觉比较方便;但是解决方案1 有其它属性会丢失问题

比如AttributeError: module 'keras.backend' has no attribute image_dim_ordering

所以建议大家采用方案2

关于"keras的get_value运行越来越慢如何解决"的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注行业资讯频道,小编每天都会为大家更新不同的知识点。

0