estimator
estimator同keras是tensorflow的高级API。在tensorflow1.13以上,estimator已经作为一个单独的package从tensorflow分离出来了。
estimator抽象了tensorflow底层的api, 同keras一样,他分离了model和data, 不同于keras这个不得不认养的儿子,estimator作为tensorflow的亲儿子,天生具有分布式的基因,更容易在生产环境里面使用
tensorflow官方文档提供了比较详细的estimator程序的构建过程:
https://www.tensorflow.org/guide#estimators
tensorflow model提供了estimator构建的mnist程序:
https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py
estimator模型由model_fn决定:
官方文档:
其中features, labels是必需的。model, params, config参数是可选的
如下是estiamtor定义的一个模型:
1 | def my_model(features, labels, mode, params): |
可以看见features, labels分别是模型的input,output。所以是必须的。但是在有的模型里面。比如bert预训练的模型里面,我们不需要训练,所以只用到features。