tf.estimator

estimator

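estimator, like keras, is one of TensorFlow's high-level APIs. Starting with TensorFlow 1.13, estimator has been split out of TensorFlow into a separate package (tensorflow_estimator).
estimator abstracts away TensorFlow's low-level APIs and, like keras, separates the model from the data. But unlike keras, an adopted child that TensorFlow had to take in, estimator is TensorFlow's own child: it is designed for distributed execution from the start, which makes it easier to use in production.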

The official TensorFlow documentation describes in considerable detail how to build an estimator program:
https://www.tensorflow.org/guide#estimators

The tensorflow/models repository provides an MNIST program built with estimator:
https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py

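An estimator's model is defined by its model_fn. According to the official documentation, model_fn takes the arguments features, labels, mode, params and config. Of these, features and labels are required, while mode, params and config are optional.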
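To make the required/optional split concrete, here is a plain-Python toy (no TensorFlow) that calls a model_fn the way the paragraph above describes: features and labels are always passed, while mode, params and config are forwarded only if the function's signature names them. call_model_fn, minimal_fn and full_fn are hypothetical names for illustration; this is a sketch of the idea, not Estimator's actual implementation.

```python
import inspect

REQUIRED_ARGS = ("features", "labels")
OPTIONAL_ARGS = ("mode", "params", "config")


def call_model_fn(model_fn, features, labels, mode, params=None, config=None):
    """Hypothetical toy: call model_fn, forwarding optional args only if declared."""
    declared = inspect.signature(model_fn).parameters
    missing = [name for name in REQUIRED_ARGS if name not in declared]
    if missing:
        raise ValueError(f"model_fn must accept {missing}")
    kwargs = {"features": features, "labels": labels}
    available = {"mode": mode, "params": params, "config": config}
    # Forward mode/params/config only when the model_fn signature names them.
    kwargs.update({name: available[name] for name in OPTIONAL_ARGS if name in declared})
    return model_fn(**kwargs)


def minimal_fn(features, labels):
    # Declares only the required arguments.
    return features + labels


def full_fn(features, labels, mode, params):
    # Also declares mode and params, so the toy passes them in.
    return f"{mode}: {features * params['scale']}"


print(call_model_fn(minimal_fn, 1, 2, "train"))                    # 3
print(call_model_fn(full_fn, 1, 2, "eval", params={"scale": 10}))  # eval: 10
```

A model_fn that omits features or labels from its signature is rejected with a ValueError in this toy.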
The following is a model defined for an estimator:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.losses, tf.metrics, tf.train)


def my_model(features, labels, mode, params):
    """DNN with one hidden layer per entry in params['hidden_units'] and learning_rate=0.1."""
    # Create the fully connected hidden layers.
    net = tf.feature_column.input_layer(features, params['feature_columns'])
    for units in params['hidden_units']:
        net = tf.layers.dense(net, units=units, activation=tf.nn.relu)

    # Compute logits (1 per class).
    logits = tf.layers.dense(net, params['n_classes'], activation=None)

    # Compute predictions.
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # Compute loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Compute evaluation metrics. tf.metrics.accuracy returns a
    # (value, update_op) pair; the summary records the updated value.
    accuracy = tf.metrics.accuracy(labels=labels,
                                   predictions=predicted_classes,
                                   name='acc_op')
    metrics = {'accuracy': accuracy}
    tf.summary.scalar('accuracy', accuracy[1])

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=metrics)

    # Create training op.
    assert mode == tf.estimator.ModeKeys.TRAIN

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
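The prediction, loss and metric ops in my_model can be hard to picture. The sketch below recomputes, in plain Python for a tiny batch of two examples, what argmax, softmax, sparse softmax cross-entropy and accuracy produce. It assumes the loss is averaged over the batch (all weights equal to 1) and ignores the fact that tf.metrics.accuracy is a streaming (value, update_op) pair; the helper names are hypothetical, not TensorFlow functions.

```python
import math


def softmax(row):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(row):
    # Index of the largest logit, i.e. the predicted class.
    return max(range(len(row)), key=lambda i: row[i])


def sparse_softmax_cross_entropy(labels, logits):
    # Mean over the batch of -log(probability assigned to the true class).
    losses = [-math.log(softmax(row)[y]) for y, row in zip(labels, logits)]
    return sum(losses) / len(losses)


def accuracy(labels, predicted):
    # Fraction of examples whose predicted class equals the label.
    return sum(int(y == p) for y, p in zip(labels, predicted)) / len(labels)


logits = [[2.0, 1.0, 0.1],
          [0.5, 2.5, 0.0]]
labels = [0, 2]  # integer class ids, as "sparse" cross-entropy expects
predicted = [argmax(row) for row in logits]
print(predicted, accuracy(labels, predicted))  # [0, 1] 0.5
print(sparse_softmax_cross_entropy(labels, logits))
```

The second example is misclassified, and because its true class gets a low probability it contributes most of the loss.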

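As you can see, features and labels are the model's input and output respectively, which is why they are required. Some models, however, only use features: for example, when running a pre-trained BERT model without training it, there are no labels to learn from, so only features is actually used.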
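To illustrate, here is a plain-Python toy model_fn that mirrors the structure of my_model: the PREDICT branch returns before labels is ever read, so it works with labels=None, while the other modes need labels to compute a loss. Mode, toy_model_fn and the params['weights'] layout are hypothetical; Mode only stands in for tf.estimator.ModeKeys.

```python
import math
from enum import Enum


class Mode(Enum):
    # Hypothetical stand-in for tf.estimator.ModeKeys.
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "predict"


def toy_model_fn(features, labels, mode, params):
    # One weight per class, each multiplied by the scalar feature x, gives the logits.
    logits = [w * features["x"] for w in params["weights"]]
    predicted = max(range(len(logits)), key=lambda i: logits[i])
    if mode is Mode.PREDICT:
        # labels is never read on this path, so None is acceptable here.
        return {"class_id": predicted, "logits": logits}
    if labels is None:
        raise ValueError("labels are required in TRAIN and EVAL mode")
    # -log softmax(logits)[labels] = logsumexp(logits) - logits[labels]
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return {"loss": log_sum_exp - logits[labels]}


params = {"weights": [0.5, -1.0, 2.0]}
print(toy_model_fn({"x": 1.0}, None, Mode.PREDICT, params))
# {'class_id': 2, 'logits': [0.5, -1.0, 2.0]}
```

Calling the same function with Mode.EVAL and labels=None raises a ValueError, which is the situation a predict-only model never runs into.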