博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tf.train.batch的偶尔乱序问题
阅读量:7089 次
发布时间:2019-06-28

本文共 2670 字,大约阅读时间需要 8 分钟。

tf.train.batch的偶尔乱序问题

觉得有用的话,欢迎一起讨论相互学习~

tf.train.batch的偶尔乱序问题

  • 我们在通过tf.Reader读取文件后,都需要用batch函数将读取的数据根据预先设定的batch_size打包为一个个独立的batch方便我们进行学习。
  • 常用的batch函数有tf.train.batch和tf.train.shuffle_batch函数。前者是将数据从前往后读取并顺序打包,后者则要进行乱序处理————即将读取的数据进行乱序后在组成批次。
  • 训练时我往往都是使用shuffle_batch函数,但是这次我在验证集上预调好模型并freeze模型后我需要在测试集上进行测试。此时我需要将数据的标签和inference后的结果进行一一对应。 此时数据出现的顺序是十分重要的,这保证我们的产品在上线前的测试集中能准确get到每个数据和inference后结果的差距 而在验证集中我们不太关心数据原有的标签和inference后的真实值,我们往往只是需要让这两个数据一一对应,关于数据出现的顺序我们并不关心。
  • 此时我们一般使用tf.train.batch函数将tf.Reader读取的值进行顺序打包即可。

    然而tf.train.batch函数往往会有偶尔乱序的情况

  • 我们将csv文件中每个数据样本从上往下依次进行标号,我们在使用tf.trian.batch函数依次进行读取,如果我们读取的数据编号乱序了,则表明tf.train.batch函数有偶尔乱序的状况。

import tensorflow as tfBATCH_SIZE = 400NUM_THREADS = 2MAX_NUM = 500def read_data(file_queue):    reader = tf.TextLineReader(skip_header_lines=1)    key, value = reader.read(file_queue)    defaults = [[0], [0.], [0.]]    NUM, C, Tensile = tf.decode_csv(value, defaults)    vertor_example = tf.stack([C])    vertor_label = tf.stack([Tensile])    vertor_num = tf.stack([NUM])    return vertor_example, vertor_label, vertor_numdef create_pipeline(filename, batch_size, num_threads):    file_queue = tf.train.string_input_producer([filename])  # 设置文件名队列    example, label, no = read_data(file_queue)  # 读取数据和标签    example_batch, label_batch, no_batch = tf.train.batch(        [example, label, no], batch_size=batch_size, num_threads=num_threads, capacity=MAX_NUM)    return example_batch, label_batch, no_batchx_train_batch, y_train_batch, no_train_batch = create_pipeline('test_tf_train_batch.csv', batch_size=BATCH_SIZE,                                                               num_threads=NUM_THREADS)init_op = tf.global_variables_initializer()local_init_op = tf.local_variables_initializer()with tf.Session() as sess:    sess.run(local_init_op)    sess.run(init_op)    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord=coord)    example, label, num = sess.run([x_train_batch, y_train_batch, no_train_batch])    print(example)    print(label)    print(num)    coord.request_stop()    coord.join(threads)

实验结果

我们将csv文件中的真实Tensile值放在第一列,将使用tf.train.batch函数得到的Tensile和no分别放在第二列和第三列

TureTensile | FalseTensile | NO
:- | :-: | -:
0.830357143 | [ 0.52678573] | [ 66]
0.526785714 | [ 0.83035713] | [ 65]
0.553571429 | [ 0.4375 ] | [ 68]
0.4375 | [ 0.5535714 ] | [ 67]
0.517857143 | [ 0.33035713] | [ 70]
0.330357143 | [ 0.51785713] | [ 69]
0.482142857 | [ 0.6785714 ] | [ 72]
0.678571429 | [ 0.48214287] | [ 71]
0.419642857 | [ 0.02678571] | [ 74]
0.026785714 | [ 0.41964287] | [ 73]
0.401785714 | [ 0.4017857 ] | [ 75]

解决方案

  • 将测试集中所有样本数据加NO顺序标签列

转载于:https://www.cnblogs.com/cloud-ken/p/9092010.html

你可能感兴趣的文章
linux的yum仓库配置
查看>>
XSUPERSMS COME ON
查看>>
[JS2] JS是弱类型
查看>>
企业搜索引擎开发之连接器connector(二十四)
查看>>
数学图形(1.9)悬链线
查看>>
有上下界的网络流问题
查看>>
AspectJ获取方法注解的信息
查看>>
HDU 4902 Nice boat(线段树)
查看>>
Codeforces Round #114 (Div. 1) E. Wizards and Bets 高斯消元
查看>>
怎样调通微信支付及微信发货通知接口(Js API)
查看>>
Android 属性动画(Property Animation) 全然解析 (下)
查看>>
推断汉字正則表達式更严谨方法!
查看>>
如何避免误删CleanMyMac语言文件
查看>>
Linux下免安装mysql
查看>>
快钱报错:javax.net.ssl.SSLProtocolException: handshake alert: unrecognized_name解决
查看>>
Hadoop集群WordCount运行详解(转)
查看>>
[转]SSM框架——详细整合教程(Spring+SpringMVC+MyBatis)
查看>>
一次性搞清楚equals和hashCode
查看>>
Android Studio IDE的 LogCat如何过滤指定应用的调试信息
查看>>
23个常用正则表达式(数值和字符串)
查看>>