tensorflow sequence_loss
sequence_loss是nlp算法中非常重要的一个函数.rnn,lstm,attention都要用到这个函数.看下面代码:
# coding: utf-8
import numpy as np
import tensorflow as tf
from tensorflow.contrib.seq2seq import sequence_loss
logits_np = np.array([
[[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
[[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]
])
targets_np = np.array([
[0, 0, 0],
[0, 0, 0]
], dtype=np.int32)
logits = tf.convert_to_tensor(logits_np)
targets = tf.convert_to_tensor(targets_np)
cost = sequence_loss(logits=logits,
targets=targets,
weights=tf.ones_like(targets, dtype=tf.float64))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
r = sess.run(cost)
print(r)
先对每个[0.5,0.5,0.5,0.5]取softmax. softmax([0.5,0.5,0.5,0.5])=(0.25,0.25,0.25,0.25)然后再计算-ln(0.25)*6/6=1.38629436112.
再看一个例子
# coding:utf-8
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from tensorflow.contrib.seq2seq import sequence_loss
import tensorflow as tf
import numpy as np
output_np = np.array(
[
[[0.6, 0.5, 0.3, 0.2], [0.9, 0.5, 0.3, 0.2], [1.0, 0.5, 0.3, 0.2]],
[[0.2, 0.5, 0.3, 0.2], [0.3, 0.5, 0.3, 0.2], [0.4, 0.5, 0.3, 0.2]]
]
)
print(output_np.shape)
target_np = np.array([[0, 1, 2],
[3, 0, 1]],
dtype=np.int32)
print(target_np.shape)
output = tf.convert_to_tensor(output_np, np.float32)
target = tf.convert_to_tensor(target_np, np.int32)
cost = sequence_loss(output,
target,
tf.ones_like(target, dtype=np.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
cost_r = sess.run(cost)
print(cost_r)
这个代码作用和下面的tf.reduce_mean(softmax_cross_entropy_with_logits)作用一致.
# coding:utf-8
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
def to_onehot(a):
max_index = np.max(a)
b = np.zeros((a.shape[0], max_index + 1))
b[np.arange(a.shape[0]), a] = 1
return b
logits_ph = tf.placeholder(tf.float32, shape=(None, None))
labels_ph = tf.placeholder(tf.float32, shape=(None, None))
output_np = np.array([
[0.6, 0.5, 0.3, 0.2],
[0.9, 0.5, 0.3, 0.2],
[1.0, 0.5, 0.3, 0.2],
[0.2, 0.5, 0.3, 0.2],
[0.3, 0.5, 0.3, 0.2],
[0.4, 0.5, 0.3, 0.2]])
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_ph, logits=logits_ph))
target_np = np.array([0, 1, 2, 3, 0, 1])
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
cost_r = sess.run(cost, feed_dict={logits_ph: output_np, labels_ph: to_onehot(target_np)})
print(cost_r)
再取交叉熵,再取平均.
tensorflow sequence_loss的更多相关文章
- tf.contrib.seq2seq.sequence_loss example:seqence loss 实例代码
#!/usr/bin/env python # -*- coding: utf-8 -*- import tensorflow as tf import numpy as np params=np.r ...
- 学习笔记CB014:TensorFlow seq2seq模型步步进阶
神经网络.<Make Your Own Neural Network>,用非常通俗易懂描述讲解人工神经网络原理用代码实现,试验效果非常好. 循环神经网络和LSTM.Christopher ...
- Tensorflow动态seq2seq使用总结(r1.3)
https://www.jianshu.com/p/c0c5f1bdbb88 动机 其实差不多半年之前就想吐槽Tensorflow的seq2seq了(后面博主去干了些别的事情),官方的代码已经抛弃原来 ...
- sequence_loss的解释
在做seq2seq的时候,经常需要使用sequence_loss这是损失函数. 现在分析一下sequence_loss这个函数到底在做什么 # coding: utf-8 import numpy a ...
- 吴裕雄--天生自然神经网络与深度学习实战Python+Keras+TensorFlow:使用TensorFlow和Keras开发高级自然语言处理系统——LSTM网络原理以及使用LSTM实现人机问答系统
!mkdir '/content/gdrive/My Drive/conversation' ''' 将文本句子分解成单词,并构建词库 ''' path = '/content/gdrive/My D ...
- Tensorflow 官方版教程中文版
2015年11月9日,Google发布人工智能系统TensorFlow并宣布开源,同日,极客学院组织在线TensorFlow中文文档翻译.一个月后,30章文档全部翻译校对完成,上线并提供电子书下载,该 ...
- tensorflow学习笔记二:入门基础
TensorFlow用张量这种数据结构来表示所有的数据.用一阶张量来表示向量,如:v = [1.2, 2.3, 3.5] ,如二阶张量表示矩阵,如:m = [[1, 2, 3], [4, 5, 6], ...
- 用Tensorflow让神经网络自动创造音乐
#————————————————————————本文禁止转载,禁止用于各类讲座及ppt中,违者必究————————————————————————# 前几天看到一个有意思的分享,大意是讲如何用Ten ...
- tensorflow 一些好的blog链接和tensorflow gpu版本安装
pading :SAME,VALID 区别 http://blog.csdn.net/mao_xiao_feng/article/details/53444333 tensorflow实现的各种算法 ...
随机推荐
- shell的运用 : jenkins 编译 打包前端发布 生产(tomcat)
生产隔离做得非常.....文件上传只能通过固定ip机器的sftp账户上传,账户密码每个月要写申请才能获得. 登陆生产服务只能通过浏览器登陆!!! 发布一次生产,很痛苦. 做了简单的shell来减轻痛苦 ...
- 【微信小程序】踩坑指南(持续更新)
前言 说明: 基于mpvue框架:mpvue官方文档 语法同vue框架:vue官方文档 小程序中会有一些坑点,这里会就工作中遇到的坑一一列举出来 无说明时请直接看代码注释 v-show无法使用在小程序 ...
- [javascript] Javascript的笔记
1.2019年10月20日12:28:16,学习HOW2J的Javascript, 2.一般见到的缩写js,就是javascript的意思: 3.javascript代码必须放在script标签中,s ...
- saprk性能调优参考
1.Tuning Spark 文档 原文:http://spark.apache.org/docs/latest/tuning.html 翻译参考:https://www.cnblogs.com/lh ...
- pat 1013 Battle Over Cities(25 分) (并查集)
1013 Battle Over Cities(25 分) It is vitally important to have all the cities connected by highways i ...
- ES6的基础知识(一)
1.ECMAScript 6.0(以下简称ES6). 2.ECMAScript 和 JavaScript 的关系是,前者是后者的规格,后者是前者的其中一种实现. 3.对ES6支持的浏览器:超过 90% ...
- PHP变量的初始化以及赋值方式介绍
什么是变量 变量通俗的来说是一种容器.根据变量类型不同,容器的大小不一样,自然能存放的数据大小也不相同.在变量中存放的数据,我们称之为变量值. PHP 中的变量用一个美元符号后面跟变量名来表示.变量名 ...
- 面试官:CPU百分百!给你一分钟,怎么排查?有几种方法?
Part0 遇到了故障怎么办? 在生产上,我们会遇到各种各样的故障,遇到了故障怎么办? 不要慌,只有冷静才是解决故障的利器. 下面以一个例子为例,在生产中碰到了CPU 100%的问题怎么办? 在生产中 ...
- 结合开源软件kaptcha讲解登录验证码功能的实现
一.验证码生成之配置使用kaptcha 使用google开源的验证码实现类库kaptcha,通过maven坐标引入 <dependency> <groupId>com.gith ...
- Scala学习系列一
一 scala介绍 Scala是一门以java虚拟机(JVM)为目标运行环境并将面向对象和函数式编程的最佳特性结合在一起的静态类型编程语言. 1) Scala 是一门多范式 (multi-parad ...