神经网络学习中的损失函数及mini-batch学习
# 损失函数(loss function)。这个损失函数可以使用任意函数,# 但一般用均方误差(mean squared error)和交叉熵误差(cross entropy error)等一切都在代码时有注释哈。
import numpy as np
from minst import load_mnist
# 损失函数(loss function)。这个损失函数可以使用任意函数,
# 但一般用均方误差(mean squared error)和交叉熵误差(cross entropy error)等
# 均方误差会计算神经网络的输出和正确解监督数据的各个元素之差的平方,再求总和
def mean_quared_error(y, t):
return 0.5 * np.sum((y-t)**2)
# 设“2”为正确解
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# “2”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(mean_quared_error(np.array(y), np.array(t)))
# “7”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(mean_quared_error(np.array(y), np.array(t)))
def cross_entropy_error(y, t):
# 保护性对策,添加一个微小值delta可以防止负无限大的发生
delta = 1e-7
if y.ndim == 1:
t = t.reshape(1, t.size)
y = y.reshape(1, y.size)
batch_size = y.shape[0]
# t 为 one-hot 表示
return -np.sum(t * np.log(y+delta)) / batch_size
# t 为标签形式时
# return -np.sum(np.log(y[np.arange(batch_size), t] + delta)) / batch_size
# 设“2”为正确解
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# “2”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(cross_entropy_error(np.array(y), np.array(t)))
# “7”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(cross_entropy_error(np.array(y), np.array(t)))
# 当数据集的训练数据有很大时,如果以全部数据为对象求损失函数的和,则计算过程需要花费较长的时间。
# 再者,如果遇到大数据,数据量会有几百万、几千万之多,这种情况下以全部数据为对象计算损失函数是不现实的。
# 因此,我们从全部数据中选出一部分,作为全部数据的“近似”。
# 神经网络的学习也是从训练数据中选出一批数据(称为mini-batch,小批量),然后对每个mini-batch进行学习。
# 比如,从60000个训练数据中随机选择100笔,再用这100笔数据进行学习。
# 这种学习方式称为mini-batch学习。
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape)
print(t_train.shape)
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
print(x_batch)
print(t_batch)
C:\Python36\python.exe C:/Users/Sahara/PycharmProjects/test1/test.py C:\Users\Sahara\PycharmProjects\test1 0.09750000000000003 0.5975 0.510825457099338 2.302584092994546 (60000, 784) (60000, 10) [[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]] [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]] Process finished with exit code 0
神经网络学习中的损失函数及mini-batch学习的更多相关文章
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- Scratch学习中需要注意的地方,学习Scratch时需要注意的地方
在所有的编程工具中,Scratch是比较简单的,适合孩子学习锻炼,也是信息学奥赛的常见项目.通常Scratch学习流程是,先掌握程序相关模块,并且了解各个模块的功能使用,然后通过项目的编写和练习,不断 ...
- 神经网络训练中的Tricks之高效BP(反向传播算法)
神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...
- 关于Linux学习中的问题和体会
本科期间未开展过与之相关的课程,所以初次接触Linux难免有些问题!参照老师给的学习资料中内容,逐步解决了一些问题,但还有一些问题没解决,下面列举出自己遇到的一些问题. 1.在环境变量与文件查找专题中 ...
- 【转载】深度学习中softmax交叉熵损失函数的理解
深度学习中softmax交叉熵损失函数的理解 2018-08-11 23:49:43 lilong117194 阅读数 5198更多 分类专栏: Deep learning 版权声明:本文为博主原 ...
- 深度学习中的batch、epoch、iteration的含义
深度学习的优化算法,说白了就是梯度下降.每次的参数更新有两种方式. 第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度.这种方法每更新一次参数都要把数据集里的所有样本都看一遍, ...
- 转载: scikit-learn学习之K-means聚类算法与 Mini Batch K-Means算法
版权声明:<—— 本文为作者呕心沥血打造,若要转载,请注明出处@http://blog.csdn.net/gamer_gyt <—— 目录(?)[+] ================== ...
- 深度学习中 Batch Normalization
深度学习中 Batch Normalization为什么效果好?(知乎) https://www.zhihu.com/question/38102762
- 一文读懂神经网络训练中的Batch Size,Epoch,Iteration
一文读懂神经网络训练中的Batch Size,Epoch,Iteration 作为在各种神经网络训练时都无法避免的几个名词,本文将全面解析他们的含义和关系. 1. Batch Size 释义:批大小, ...
随机推荐
- [LeetCode] 110. Balanced Binary Tree 平衡二叉树
Given a binary tree, determine if it is height-balanced. For this problem, a height-balanced binary ...
- 可扩展标记语言XML之一:XML的概念、作用与示例
哈喽大家好啊,乐字节小乐又来给大家分享Java技术文章了.上次已经讲完了Java多线程相关知识(可以看我博客文章), 这次文章将讲述可扩展标记语言XML 一. 标记语言 标记语言,是一种将文本(Tex ...
- 根据GSVA结果绘制不同组的趋势图
首先需要将GSVA的矩阵结果转换成如下格式: 然后使用如下代码进行作图 infile <- "draw_pre_violin_heatmap.txt" data <- ...
- java email发送(附件中文的处理)
这里使用的是commons-email-1.3.2.jar进行的开发,自认为这是简单的邮件发送. package com.yt.base.common; import java.io.Unsuppor ...
- Django中的admin
1.基本知识 在用Django框架写了一个网站之后,我们添加数据大概有两种方式: 1.在连接的数据库中添加数据 2.登录admin,进入后台添加数据 创建一个Django项目后,我们在url.py中会 ...
- python基础 — 循环重新输入
后续完善各种循环案例 while True: try: str_num = input('input a number:') num = float(str_num) print("你输入的 ...
- Crazy Binary String(前缀和)(2019牛客暑期多校训练营(第三场))
示例: 输入: 801001001 输出:4 6 题意:一段长度为n且只有 ‘0’ 和 ‘1’ 的字符串,求子串中 ‘0’ 和 ‘1’ 数目相等和子序列中 ‘0’ 和 ‘1’ 数目相等的最大长度. 思 ...
- VC6.0- C语言-winsocket-警告warning C4761
错误介绍 操作系统:windows10 IDE:VC6.0 语言:C语言 项目内容简介:编写一个双人网络海战棋对战游戏 警告类型:警告warning C4761 integral size misma ...
- SpringCloud之Hystrix集群及集群监控turbine
目的: Hystrix集群及监控turbine Feign.Hystrix整合之服务熔断服务降级彻底解耦 集群后超时设置 Hystrix集群及监控turbine 新建一个springboot工程mic ...
- Xshell连接虚拟机文档教程
1打开VirtualBox 2 找到导入的虚拟机 3右键虚拟机 启动 4 等待加载 5 加载的时候,打开xshell 6 7 填写框住的内容 名称: 自己取 主机: 127.0.0.1 固定内容 端 ...