tensorflow_hub

tensorflow hub是用来加载预训练模型的,比如说bert模型什么的。
本篇博客讲述如何训练导出一个tensorflow hub 所拥有的模型。然后如何加载这么一个模型

如下是训练导出一个tensorflow hub模块所接受的模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

import tensorflow as tf
import tensorflow_hub as hub


'''2 - 建立一个网络结构,并基于该网络结构建立一个Module '''
def half_plus_two():
'''该函数主要是创建一个简单的模型,其网络结构就是y = a*x + b '''
# 创建两个变量,a和b,如网络的权重和偏置
a = tf.get_variable('a', shape=[])
b = tf.get_variable('b', shape=[])
# 创建一个占位变量,为后面graph的输入提供准备
x = tf.placeholder(tf.float32)
# 创建一个完整的graph
y = a*x + b
# 通过hub的add_signature,建立hub需要的网络
# 这个函数决定了如何去加载这个模型。
hub.add_signature(name='info', inputs={'x': x}, outputs={'y': y})


def export_module(path):
'''该函数用于调用创建api进行module创建,然后进行网络的权重赋值,最后通过session进行运行权重初始化,并最后输出该module'''
# 通过hub的create_module_spec,接收函数建立一个Module
spec = hub.create_module_spec(half_plus_two)
# 防止串graph,将当期的操作放入同一个graph中
with tf.Graph().as_default():
# 通过hub的Module读取一个模块,该模块可以是url链接,表示从tensorflow hub去拉取,
# 或者接收上述创建好的module
module = hub.Module(spec)
# 这里演示如何将权重值赋予到graph中的变量,如从checkpoint中进行变量恢复等
init_a = tf.assign(module.variable_map['a'], 0.5)
init_b = tf.assign(module.variable_map['b'], 2.0)
init_vars = tf.group([init_a, init_b])

with tf.Session() as sess:
# 运行初始化,为了将其中变量的值设置为赋予的值
sess.run(init_vars)
# 将模型导出到指定路径
module.export(path, sess)


def main(argv):
export_module('hub_dict')

if __name__ == '__main__':
tf.app.run(main)

加载上面代码所创建的模型的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import subprocess

import tensorflow as tf
import tensorflow_hub as hub


class HalfPlusTwoTests(tf.test.TestCase):

def testExportTool(self):

# 指定module的文件夹位置,这里是export
module_path = os.path.join('.', 'hub_dict')


with tf.Graph().as_default():
# 读取当前存在的一个module
m = hub.Module(module_path)
# 如直接采用y=f(x) 一样进行调用,
# 这个地方调用hub.Module.__call__函数,这个函数比较重要的参数是inputs, signature, as_dict
# 这个地方如何使用取决于hub.add_signature如何构建网络。
# add_signature的name控制hub.Module.__call__的signature参数。如果name为None的话,则name默认设置为default,hub.Module.__call__的signature参数不用传参,如果传参必须为default。如果name不是默认参数的话,那么signature必须同name是同一个字符串。
# add_signature的inputs控制hub.Module.__call__的inputs, 如果add_signature的inputs为多输入的话(dict),则hub.Module.__call__的inputs必须为dict, 且key必须相同, value同add_signature的inputs的value所定义的placeholder相同
# add_signature的outputs控制hub.Module.__call__的as_dict, 如果outputs为多个输出的话(必须为dict形式),则as_dict必须为True。call函数返回一个dict, 每一项的key对应于outputs里面的每一项
output = m({'x': [10, 3, 4]}, signature='info', as_dict=True)

print(output)

with tf.Session() as sess:
# 惯例进行全局变量初始化
sess.run(tf.initializers.global_variables())
# 观察生成的值是否与预定义值一致,即prediction是否与label一致
self.assertAllEqual(sess.run(output['y']), [7, 3.5, 4])

if __name__ == '__main__':
tf.test.main()

官方讲解tensorflowHub的文档