在分析训练代码的时候,遇到了,tf.contrib.crf.crf_log_likelihood,这个函数,于是想简单理解下:

函数的目的:使用crf 来计算损失,里面用到的优化方法是:最大似然估计

使用方法:

tf.contrib.crf.crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params=None)
See the guide: CRF (contrib) Computes the log-likelihood of tag sequences in a CRF. Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we compute the log-likelihood.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix, if available. Returns:
log_likelihood: A scalar containing the log-likelihood of the given sequence of tag indices.
transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function.

函数讲解:

1、tf.contrib.crf.crf_log_likelihood

crf_log_likelihood(inputs,tag_indices,sequence_lengths,transition_params=None)

在一个条件随机场里面计算标签序列的log-likelihood

参数:

inputs: 一个形状为[batch_size, max_seq_len, num_tags] 的tensor,一般使用BILSTM处理之后输出转换为他要求的形状作为CRF层的输入.
tag_indices: 一个形状为[batch_size, max_seq_len] 的矩阵,其实就是真实标签.
sequence_lengths: 一个形状为 [batch_size] 的向量,表示每个序列的长度.
transition_params: 形状为[num_tags, num_tags] 的转移矩阵

返回:

log_likelihood: 标量,log-likelihood
transition_params: 形状为[num_tags, num_tags] 的转移矩阵

2、tf.contrib.crf.viterbi_decode

viterbi_decode(score,transition_params) 
通俗一点,作用就是返回最好的标签序列.这个函数只能够在测试时使用,在tensorflow外部解码

参数:

score: 一个形状为[seq_len, num_tags] matrix of unary potentials.
transition_params: 形状为[num_tags, num_tags] 的转移矩阵

返回:

viterbi: 一个形状为[seq_len] 显示了最高分的标签索引的列表.
viterbi_score: A float containing the score for the Viterbi sequence.

3、tf.contrib.crf.crf_decode

crf_decode(potentials,transition_params,sequence_length) 
在tensorflow内解码

参数:

potentials: 一个形状为[batch_size, max_seq_len, num_tags] 的tensor,
transition_params: 一个形状为[num_tags, num_tags] 的转移矩阵
sequence_length: 一个形状为[batch_size] 的 ,表示batch中每个序列的长度

返回:

decode_tags:一个形状为[batch_size, max_seq_len] 的tensor,类型是tf.int32.表示最好的序列标记.
best_score: 有个形状为[batch_size] 的tensor, 包含每个序列解码标签的分数.

转载来自知乎:

如果你需要预测的是个序列,那么可以选择用crf_log_likelihood作为损失函数

crf_log_likelihood(
inputs,
tag_indices,
sequence_lengths,
transition_params=None
)

输入:

inputs:unary potentials,也就是每个标签的预测概率值,这个值根据实际情况选择计算方法,CNN,RNN...都可以

tag_indices,这个就是真实的标签序列了

sequence_lengths,这是一个样本真实的序列长度,因为为了对齐长度会做些padding,但是可以把真实的长度放到这个参数里

transition_params,转移概率,可以没有,没有的话这个函数也会算出来

输出:

log_likelihood,

transition_params,转移概率,如果输入没输,它就自己算个给返回

作者:知乎用户
链接:https://www.zhihu.com/question/57666556/answer/326803900
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

官方的示例代码:如何使用crf来计算:

# !/home/wcg/tools/local/anaconda3/bin/python
# coding=utf8
import numpy as np
import tensorflow as tf #data settings
num_examples = 10
num_words = 20
num_features = 100
num_tags = 5 # 5 tags
#x shape = [10,20,100]
#random features.
x = np.random.rand(num_examples,num_words,num_features).astype(np.float32) #y shape = [10,20] #Random tag indices representing the gold sequence.
y = np.random.randint(num_tags,size = [num_examples,num_words]).astype(np.int32) # 序列的长度
#sequence_lengths = [19,19,19,19,19,19,19,19,19,19]
sequence_lengths = np.full(num_examples,num_words - 1,dtype=np.int32) #Train and evaluate the model.
with tf.Graph().as_default():
with tf.Session() as session:
# Add the data to the TensorFlow gtaph.
x_t = tf.constant(x) #观测序列
y_t = tf.constant(y) # 标记序列
sequence_lengths_t = tf.constant(sequence_lengths) # Compute unary scores from a linear layer.
# weights shape = [100,5]
weights = tf.get_variable("weights", [num_features, num_tags]) # matricized_x_t shape = [200,100]
matricized_x_t = tf.reshape(x_t, [-1, num_features]) # compute [200,100] [100,5] get [200,5]
# 计算结果
matricized_unary_scores = tf.matmul(matricized_x_t, weights) # unary_scores shape = [10,20,5] [10,20,5]
unary_scores = tf.reshape(matricized_unary_scores, [num_examples, num_words, num_tags])
# compute the log-likelihood of the gold sequences and keep the transition
# params for inference at test time.
# shape shape [10,20,5] [10,20] [10]
log_likelihood,transition_params = tf.contrib.crf.crf_log_likelihood(unary_scores,y_t,sequence_lengths_t) viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(unary_scores, transition_params, sequence_lengths_t)
# add a training op to tune the parameters.
loss = tf.reduce_mean(-log_likelihood) # 定义梯度下降算法的优化器
#learning_rate 0.01
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) #train for a fixed number of iterations.
session.run(tf.global_variables_initializer()) '''
#eg:
In [61]: m_20
Out[61]: array([[ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]) In [62]: n_20
Out[62]: array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) In [59]: n_20<m_20
Out[59]: array([[ True, True, True, True, True, True, True, True, True, True]], dtype=bool) '''
#这里用mask过滤掉不符合的结果
mask = (np.expand_dims(np.arange(num_words), axis=0) < np.expand_dims(sequence_lengths, axis=1)) ###mask = array([[ True, True, True, True, True, True, True, True, True, True]], dtype=bool)
#序列的长度
total_labels = np.sum(sequence_lengths) print ("mask:",mask) print ("total_labels:",total_labels)
for i in range(1000):
#tf_unary_scores,tf_transition_params,_ = session.run([unary_scores,transition_params,train_op])
tf_viterbi_sequence,_=session.run([viterbi_sequence,train_op])
if i%100 == 0:
'''
false*false = false false*true= false ture*true = true
'''
#序列中预测对的个数
correct_labels = np.sum((y==tf_viterbi_sequence) * mask)
accuracy = 100.0*correct_labels/float(total_labels)
print ("Accuracy: %.2f%%" %accuracy)

tensorflow笔记3:CRF函数:tf.contrib.crf.crf_log_likelihood()的更多相关文章

  1. (四) tensorflow笔记:常用函数说明

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  2. TensorFlow高层次机器学习API (tf.contrib.learn)

    TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格 ...

  3. tensorflow 笔记12:函数区别:placeholder,variable,get_variable,参数共享

    一.函数意义: 1.tf.Variable() 变量 W = tf.Variable(<initial-value>, name=<optional-name>) 用于生成一个 ...

  4. tensorflow笔记4:函数:tf.assign()、tf.assign_add()、tf.identity()、tf.control_dependencies()

    函数原型: tf.assign(ref, value, validate_shape=None, use_locking=None, name=None)   Defined in tensorflo ...

  5. [TensorFlow笔记乱锅炖] tf.multinomial(logits, num_samples)使用方法

    tf.multinomial(logits, num_samples) 第一个参数logits可以是一个数组,每个元素的值可以简单地理解为对应index的选择概率,注意这里的概率没有规定加起来的和为1 ...

  6. tensorflow笔记 :常用函数说明

    常用函数说明,梯度.产生变量等 http://blog.csdn.net/c2a2o2/article/details/69061539

  7. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  8. tensorflow笔记:使用tf来实现word2vec

    (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 (四) tensorflow笔 ...

  9. tf.contrib.layers.fully_connected参数笔记

    tf.contrib.layers.fully_connected 添加完全连接的图层. tf.contrib.layers.fully_connected(    inputs,    num_ou ...

随机推荐

  1. centos 7 系统启动不了 出现报错dependency failed for /mnt , dependency failed for local file systems

    阿里云一台Ecs重启后启动不了,出现报错 dependency failed for /mnt , dependency failed for local file systems ,  报错的原因  ...

  2. Delphi单元文件引用名称问题

    Delphi新版本的单元文件格式变化了,如windows变成了winapi.windows,如果想在单元引用中使用简称,则需要在工程选项中配置: 这样就可以使用全名或简写来引用单元了.

  3. MATLAB 的函数

    [需要注意]MATLAB函数不能先定义后调用! 如下为先定义后调用,结果报错: 错误: 文件:justTest2.m 行:88 列:5脚本中的函数定义必须出现在文件的结尾.请将 "mymax ...

  4. Python实现鸢尾花数据集分类问题——使用LogisticRegression分类器

    . 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题,常见的是二分类或二项分布问题,也可以处理多分类问题,它实际上是属于一种分类方法. 概率p与因变量往 ...

  5. Java多线程编程:Callable、Future和FutureTask浅析

    通过前面几篇的学习,我们知道创建线程的方式有两种,一种是实现Runnable接口,另一种是继承Thread,但是这两种方式都有个缺点,那就是在任务执行完成之后无法获取返回结果,那如果我们想要获取返回结 ...

  6. zabbix v3.0安装部署

    这篇文章没有写明init的部分要注意 zabbix v3.0安装部署 摘要: 本文的安装过程摘自http://www.ttlsa.com/以及http://b.lifec-inc.com ,和站长凉白 ...

  7. MVC2 扩展Models和自定义验证(学习笔记)

    当我们利用Visual Studio生成实体类以后,难免会用到验证功能(例如,用户登录时验证用户名是否为空,并加以显示). Visual Studio实体类:实体类 如果直接去编辑Visual Stu ...

  8. 【Oracle】Oracle中常用的系统函数

    Oracle SQL 提供了用于执行特定操作的专用函数.这些函数大大增强了 SQL 语言的功能.函数可以接受零个或者多个输入参数,并返回一个输出结果.在Oracle还可以自定义函数,关于更多信息可以查 ...

  9. Mac 下查看网络端口占用情况

    1.Mac 下查看网络端口占用情况 有的时候关闭了服务器,但是端口还是占用,解决的方法是 kill 掉占用该端口的进程. # 查看 8009 端口的占用情况 $ lsof -i:8009 可以看到,该 ...

  10. Swift 类

    1.类概念 类是一种抽象的概念,它实际上并不存在于现实中的时间和空间里. 在定义一个类时,我们就会在类中定义该类事物的特性和行为,从术语上讲,特性叫类的属性,行为叫类的方法. 类是一个静态的概念,类本 ...