博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow 变量共享
阅读量:4325 次
发布时间:2019-06-06

本文共 2219 字,大约阅读时间需要 7 分钟。

tensorflow 变量共享涉及到几个常用的方法,tf.get_variable, tf.variable_scope, tf.reuse_variables等

为了记忆各个方法的功能,与其他方法做一个对比。

tf.variable 与 tf.get_variable

tensorflow中有两种方法生成变量variable, 一种是tf.get_variable(), 另一种是tf.Variable()。

tf.Variable() 在定义name 相同的变量时,为了不重复变量名,会自动给变量赋一个区别于前一个变量的名字(末尾_1等)。因此使用tf.Variable() 定义的变量无法通过name属性获取tf.Variable对象。

tf.get_Variable()的name参数则是唯一的识别标准,因此只要通过tf.get_variables(name='xx')获取得到的变量都是同一个变量。因此更方便于参数共享。而在重复使用时,一定要在代码中强调scope.reuse_variables(),否则系统将会报错。

因此,推荐不论在任何时候创建变量都使用tf.get_variable(),从而可以在任何地方对他进行共享。

tf.name_scope() 与 tf.variable_scope():

tf.name_scope()可以简单的理解为为了更好的管理命名空间的方法。且只会影响tf.Variable()定义的变量的name。

tf.variable_scope()则可以影响到tf.get_variable()创建的对象的name。因此可以与tf.get_variable()一同使用,完成变量共享的目的,同时对命名空间进行管理。

tf.layers 参数复用

例如tf.layers.dense(), tf.layers.conv2D()等,参数复用只需要再tf.layers.dense(x, 4, name='h1', reuse=True),使得参数reuse为True,即可复用上一层的参数。

为了验证,我们可以通过:

x = tf.ones((1, 3))y1 = tf.layers.dense(x, 4, name='h1')y2 = tf.layers.dense(x, 4, name='h1', reuse=True)# y1 and y2 will evaluate to the same valuessess = tf.Session()sess.run(tf.global_variables_initializer())print(sess.run(y1))print(sess.run(y2))  # both prints will return the same values

观察输出变量是否相同的方式来判断两个dense层是否共享参数。

siamese LSTM networks(暹罗LSTM网络结构)

以QA问题、答案对匹配为例:

为实现暹罗神经网络,我们需要使用共享参数的LSTM网络分别对问题和答案提取特征。
不同于tf.nn.dense,由于tf.nn.dynamic_rnn等是基于tf.nn.rnn_cell.LSTMCell() 构造得到的网络结构,因此只需要让dynamic_rnn(cell)中的cell输入为同一个LSTMCell即可:

utterance_gru = tf.nn.rnn_cell.LSTMCell(self.rnn_units, initializer=tf.orthogonal_initializer(),                 state_is_tuple=True,)_, utterance_gru_embeddings = tf.nn.dynamic_rnn(utterance_gru, all_utterance_embeddings,                 sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='utterance_rnn')utterance_gru_embeddings = utterance_gru_embeddings[1]   _, response_gru_embeddings= tf.nn.dynamic_rnn(utterance_gru, response_embeddings,                sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='response_rnn')response_gru_embeddings = response_gru_embeddings[1]self.utt = utterance_gru_embeddings[0]self.res = response_gru_embeddings[0] # if all_utterance_embeddings == response_embeddings, self.utt == self.res

转载于:https://www.cnblogs.com/wuchengze/p/9031736.html

你可能感兴趣的文章
07 js自定义函数
查看>>
jQueru中数据交换格式XML和JSON对比
查看>>
form表单序列化后的数据转json对象
查看>>
[PYTHON]一个简单的单元測试框架
查看>>
iOS开发网络篇—XML数据的解析
查看>>
[BZOJ4303]数列
查看>>
一般处理程序在VS2012中打开问题
查看>>
C语言中的++和--
查看>>
thinkphp3.2.3入口文件详解
查看>>
POJ 1141 Brackets Sequence
查看>>
Ubuntu 18.04 root 使用ssh密钥远程登陆
查看>>
Servlet和JSP的异同。
查看>>
虚拟机centOs Linux与Windows之间的文件传输
查看>>
ethereum(以太坊)(二)--合约中属性和行为的访问权限
查看>>
IOS内存管理
查看>>
middle
查看>>
[Bzoj1009][HNOI2008]GT考试(动态规划)
查看>>
Blob(二进制)、byte[]、long、date之间的类型转换
查看>>
OO第一次总结博客
查看>>
day7
查看>>