python的tf.train.batch函数怎么用
这篇文章主要介绍"python的tf.train.batch函数怎么用"的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇"python的tf.train.batch函数怎么用"文章能帮助大家解决问题。
tf.train.batch函数
tf.train.batch( tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None)
其中:
1、tensors:利用slice_input_producer获得的数据组合。
2、batch_size:设置每次从队列中获取出队数据的数量。
3、num_threads:用来控制线程的数量,如果其值不唯一,由于线程执行的特性,数据获取可能变成乱序。
4、capacity:一个整数,用来设置队列中元素的最大数量
5、allow_samller_final_batch:当其为True时,如果队列中的样本数量小于batch_size,出队的数量会以最终遗留下来的样本进行出队;当其为False时,小于batch_size的样本不会做出队处理。
6、name:名字
测试代码
1、allow_samller_final_batch=True
import pandas as pdimport numpy as npimport tensorflow as tf# 生成数据def generate_data(): num = 18 label = np.arange(num) return label# 获取数据def get_batch_data(): label = generate_data() input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2) label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True) return label_batch# 数据组label = get_batch_data()sess = tf.Session()# 初始化变量sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())# 初始化batch训练的参数coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess,coord)try: while not coord.should_stop(): # 自动获取下一组数据 l = sess.run(label) print(l)except tf.errors.OutOfRangeError: print('Done training')finally: coord.request_stop()coord.join(threads)sess.close()
运行结果为:
[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 0 1]
[2 3 4 5 6]
[ 7 8 9 10 11]
[12 13 14 15 16]
[17]
Done training
2、allow_samller_final_batch=False
相比allow_samller_final_batch=True,输出结果少了[17]
import pandas as pdimport numpy as npimport tensorflow as tf# 生成数据def generate_data(): num = 18 label = np.arange(num) return label# 获取数据def get_batch_data(): label = generate_data() input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2) label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False) return label_batch# 数据组label = get_batch_data()sess = tf.Session()# 初始化变量sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())# 初始化batch训练的参数coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess,coord)try: while not coord.should_stop(): # 自动获取下一组数据 l = sess.run(label) print(l)except tf.errors.OutOfRangeError: print('Done training')finally: coord.request_stop()coord.join(threads)sess.close()
运行结果为:
[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 0 1]
[2 3 4 5 6]
[ 7 8 9 10 11]
[12 13 14 15 16]
Done training
关于"python的tf.train.batch函数怎么用"的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注行业资讯频道,小编每天都会为大家更新不同的知识点。