简介
本篇博客讲解tensorflow如何将数据灌进模型进行训练
tensorflow有两种将数据灌进tensorflow计算图的方法
- 使用
tf.placeholder - 使用
tf.Data
tf.placeholder
占位符适用于简单的实验
参考:
tf.data
tf.data模块是tensorflow将数据流式传输到模型的首选方法。也是本篇博客的重点
参考:
在该模块中TensorFlow引入了两个新的抽象类
tf.data.Dataset表示一系列元素,其中每个元素包含一个或多个Tensor对象。例如,在图像管道中,元素可能是单个训练样本,具有一对表示图像数据和标签的张量。tf.data.Iterator提供了从数据集中提取元素的主要方法。
tf.data.Dataset
tf.data.Dataset代表了一系列元素,每个元素包含一个或者多个Tensor对象。例如,在图像管道中,元素可能是单个训练样本,具有一对表示图像数据和标签的张量。可以通过两种不同的方式来创建数据集:
- 创建来源(例如
Dataset.from_tensor_slices()),以通过一个或多个tf.Tensor对象构建数据集。 - 应用转换(例如
Dataset.batch()),以通过一个或多个tf.data.Dataset对象构建数据集。
创建数据集
下面将详细讲解如何创建数据集
创建来源
tf.data.Dataset提供了四个静态方法来创建数据集
tf.data.Dataset.from_tensors
函数签名为def from_tensors(tensors) -> TensorDataset
使用方法差不多同下面的tf.data.Dataset.from_tensor_slices
tf.data.Dataset.from_tensor_slices
函数签名为def from_tensor_slices(tensors) -> TensorSliceDataset
代码使用样例如下:
1 | dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) |
和tf.data.Dataset.from_tensors不同在于,这种方法接受tuple和dict作为输入。每一个tensor的第0维的size必须一致。
tf.data.Dataset.from_generator
函数签名为def from_generator(generator, output_types, output_shapes=None, args=None) -> Dataset
代码样例如下:1
2
3
4
5
6
7
8
9
10
11
12import itertools
def gen():
for i in itertools.count(1):
yield (i, [1]*i)
ds = tf.data.Dataset.from_generator(gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
value = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
print(sess.run(value)) # (1, array([1]))
print(sess.run(value)) # (2, array([1, 1]))
详细使用方法见:
tf.data.Dataset.range
函数签名为def range(*args) -> Dataset
使用方法和python内置方法range差不多。
可参见文档:
应用转换
此部分参考后面的tf.data.Dataset.map(f)部分
tf.data.Iterator
tf.data.Iterator提供了从 Dataset中提取元素的主要方法。Iterator.get_next() 返回的操作会在执行时生成 Dataset 的下一个元素,并且此操作通常充当输入管道代码和模型之间的接口。
创建迭代器
构建了表示输入数据的 Dataset 后,下一步就是创建 Iterator 来访问该数据集中的元素。tf.data API 目前支持下列迭代器,复杂程度逐渐增大:
- 单次
- 可初始化
- 可重新初始化
- 可馈送
单次
单次迭代器仅仅对数据集进行一次迭代。不需要显示初始化
展示代码如下:1
2
3
4
5dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()for i in range(100):
value = sess.run(next_element)
assert i == value
单词迭代器虽然只可以迭代一次,但是依旧可以在Dataset上对数据集进行操作,使得可以对数据集进行多轮训练。比如我们想在数据集上进行10轮训练,只需要修改上面的代码为:1
2
3
4
5dataset = tf.data.Dataset.range(100).repeat(10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()for i in range(100):
value = sess.run(next_element)
assert i == value
可初始化
您需要先运行显式 iterator.initializer 操作,然后才能使用可初始化迭代器。相比单次迭代器,其每初始化一次就可以迭代一次数据。故其相对于单次迭代器如果需要对数据进行多轮训练只能在Dataset上面做手脚不同,可初始化迭代器可以对迭代器进行多次初始化从而进行多次迭代。
下面是该迭代器使用方式:
1 | max_value = tf.placeholder(tf.int64, shape=[]) |
可重新初始化
可重新初始化迭代器可以通过多个不同的 Dataset 对象进行初始化。故而其相对于可初始化迭代器,其可以使用一个Dataset进行训练,另外一个Dataset进行验证
使用代码如下:
1 | """可重新初始化迭代器使用方法""" |
运行上面的代码其结果为1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27train 0 0
train 1 1
train 2 2
train 3 3
train 4 4
test 0 0
test 1 1
test 2 2
test 3 3
train 0 0
train 1 1
train 2 2
train 3 3
train 4 4
test 0 0
test 1 1
test 2 2
test 3 3
train 0 0
train 1 1
train 2 2
train 3 3
train 4 4
test 0 0
test 1 1
test 2 2
test 3 3
可以看到上面可重新初始化迭代器可以在多个数据集上切换,即可以做early_stop这样的防止过拟合的手段。
但是其在每一次切换数据的时候,需要进行初始化,然后导致即使之前的数据集没跑完 (比如上面每次在training_dataset上跑,因为每次train的循环取5个数据,但是整个数据集6个数据,还有个数据没用上),也没有机会使用后面的数据了。
通过合适的手段可以使每一次初始化之前,数据跑满,即可实现每次跑完一批数据,然后进行验证,从而实现keras.fit的early_stop功能。代码如下
1 | for i in range(3): |
可馈送
可馈送迭代器可以实现可重新初始化迭代器的所有功能,而且相较后者迭代器之间进行切换时需要从数据集的开头初始化迭代器不同,可馈送迭代器可以从数据集没跑完的那个断点处接着跑。
这样的一个应用场景是,比如说我们在做机器翻译的时候,如果数据集非常大,那么我们可能只需要在数据集上跑五六个epoch即可得到一个很好的模型,这个时候使用可重新初始化迭代器(或者等同于keras),我们很难去做early_stop。原因是如果我们要充分利用所有数据,我们可重新初始化迭代器的验证点只能在所有训练数据跑完一轮才可以,但是我们的模型就跑不了几轮,所以很难去做early_stop。如果我们想做early_stop,那么我们就会损失训练数据。而可馈送迭代器就是解决这个问题的,其验证点可以顺便放在训练数据的哪个点都可以。
如下是可馈送迭代器的使用代码示例:
1 | """可馈送迭代器使用方法""" |
上面代码的输出结果为:
1 | train 0 0 |
保存迭代器状态
有时候我们不能从一开始就可以训练好模型(比如资源比较紧迫的实验室)。这个时候就需要将迭代器的状态也给存储下来,以便从数据的某个点继续训练。
这部分可以详细看https://www.tensorflow.org/guide/datasets#saving_iterator_state
读取数据
tf.data可以从四种格式里面读取数据
- Numpy数组
- TFRecord数据
- 文本数据
- CSV数据
详情可以看 https://www.tensorflow.org/guide/datasets#reading_input_data
tf.data.Dataset.map()预处理数据
tf.data.Dataset.map(f)转换通过将指定函数 f 应用于输入数据集的每个元素来生成新数据集。同时因为tensorflow使用静态图的缘故,很难使用外部python库去处理一些逻辑,但是在解析数据的时候,调用外部python库很有用,tensorflow通过提供tf.py_func()来将外部库的逻辑纳入tensorflow处理数据里面
这块可以参考: https://www.tensorflow.org/guide/datasets#preprocessing_data_with_datasetmap
批处理处理数据
tf.data.Dataset提供了两个方法来进行batch训练:
Dataset.batch,简单的批处理。适用于每条数据的张量形状一样Dataset.padded_batch,适用于自然语言处理的seq类型的每个样本数据长度不一的数据。
详细的两个方法的使用可以参考: https://www.tensorflow.org/guide/datasets#batching_dataset_elements
训练工作流程
资料 https://www.tensorflow.org/guide/datasets#training_workflows 简要的讲解了tf.data如何将数据灌入模型
api
这部分讲解一些Dataset类的方法的简介说明
zip(datasets)
静态方法。类似于python内置函数zip。转换多个数据集为一个多输入数据集
concatenate(self, dataset)
拼接两个dataset为一个更大的数据集
prefetch(self, buffer_size)
list_files(file_pattern, shuffle=None, seed=None)
静态方法。列出有哪些文件
repeat(self, count=None)
重复数据集一定次数。count也叫epoch
shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None)
cache(self, filename=””)
cache数据集
take(self, count)
取出数据集里面的前count个数据。如果count是-1或者它大于数据集本身的大小就取出所有数据
使用方法如下:
1 | import tensorflow as tf |
输出为:
1 | 0 |
skip(self, count)
将前count个数据舍弃。1
2
3
4
5
6
7
8
9
10
11
12
13
14import tensorflow as tf
dataset = tf.data.Dataset.range(10).take(4).skip(1)
iterator = dataset.make_one_shot_iterator()
next_ele = iterator.get_next()
with tf.Session() as sess:
while 1:
try:
print(sess.run(next_ele))
except tf.errors.OutOfRangeError:
print('iterator over')
break
输出为:
1 | 1 |
shard(self, num_shards, index)
这个方法在分布式训练里面特别有用。作用是将数据分成num_shards份分布在不同机器上训练。详细使用可以看官方文档 https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
flat_map(self, map_func)
将数据集里面的每个元素展平然后拼接。
1 | # NOTE: The following examples use `{ ... }` to represent the |
filter(self, predicate)
可以参考 机器翻译数据集用filter过滤掉src或tgt数据里面长度为0的坏数据
apply(self, transformation_func)
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#apply
window(self, size, shift=None, stride=1, drop_remainder=False)
tf.data.Dataset.range(7).window(2) 产生 [{0, 1}, {2, 3}, {4, 5}, {6}]这样的数据。详情看官方文档
reduce(self, initial_state, reduce_func)
该函数返回一个tensor,而不是一个Dataset类
使用如下:1
2
3
4
5import tensorflow as tf
import numpy as np
t = tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)
with tf.Session() as sess:
print(sess.run(t)) # 10