# 损失函数(loss function)。这个损失函数可以使用任意函数,# 但一般用均方误差(mean squared error)和交叉熵误差(cross entropy error)等一切都在代码时有注释哈。
import numpy as np
from minst import load_mnist

# 损失函数(loss function)。这个损失函数可以使用任意函数,
# 但一般用均方误差(mean squared error)和交叉熵误差(cross entropy error)等

# 均方误差会计算神经网络的输出和正确解监督数据的各个元素之差的平方,再求总和
def mean_quared_error(y, t):
    return 0.5 * np.sum((y-t)**2)

# 设“2”为正确解
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# “2”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(mean_quared_error(np.array(y), np.array(t)))
# “7”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(mean_quared_error(np.array(y), np.array(t)))

def cross_entropy_error(y, t):
    # 保护性对策,添加一个微小值delta可以防止负无限大的发生
    delta = 1e-7
    if y.ndim == 1:
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)
    batch_size = y.shape[0]
    # t 为 one-hot 表示
    return -np.sum(t * np.log(y+delta)) / batch_size
    #  t 为标签形式时
    # return -np.sum(np.log(y[np.arange(batch_size), t] + delta)) / batch_size

# 设“2”为正确解
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# “2”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(cross_entropy_error(np.array(y), np.array(t)))
# “7”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(cross_entropy_error(np.array(y), np.array(t)))

# 当数据集的训练数据有很大时,如果以全部数据为对象求损失函数的和,则计算过程需要花费较长的时间。
# 再者,如果遇到大数据,数据量会有几百万、几千万之多,这种情况下以全部数据为对象计算损失函数是不现实的。
# 因此,我们从全部数据中选出一部分,作为全部数据的“近似”。
# 神经网络的学习也是从训练数据中选出一批数据(称为mini-batch,小批量),然后对每个mini-batch进行学习。
# 比如,从60000个训练数据中随机选择100笔,再用这100笔数据进行学习。
# 这种学习方式称为mini-batch学习。

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape)
print(t_train.shape)
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
print(x_batch)
print(t_batch)
C:\Python36\python.exe C:/Users/Sahara/PycharmProjects/test1/test.py
C:\Users\Sahara\PycharmProjects\test1
0.09750000000000003
0.5975
0.510825457099338
2.302584092994546
(60000, 784)
(60000, 10)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]

Process finished with exit code 0

  

神经网络学习中的损失函数及mini-batch学习的更多相关文章

  1. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  2. Scratch学习中需要注意的地方,学习Scratch时需要注意的地方

    在所有的编程工具中,Scratch是比较简单的,适合孩子学习锻炼,也是信息学奥赛的常见项目.通常Scratch学习流程是,先掌握程序相关模块,并且了解各个模块的功能使用,然后通过项目的编写和练习,不断 ...

  3. 神经网络训练中的Tricks之高效BP(反向传播算法)

    神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...

  4. 关于Linux学习中的问题和体会

    本科期间未开展过与之相关的课程,所以初次接触Linux难免有些问题!参照老师给的学习资料中内容,逐步解决了一些问题,但还有一些问题没解决,下面列举出自己遇到的一些问题. 1.在环境变量与文件查找专题中 ...

  5. 【转载】深度学习中softmax交叉熵损失函数的理解

    深度学习中softmax交叉熵损失函数的理解 2018-08-11 23:49:43 lilong117194 阅读数 5198更多 分类专栏: Deep learning   版权声明:本文为博主原 ...

  6. 深度学习中的batch、epoch、iteration的含义

    深度学习的优化算法,说白了就是梯度下降.每次的参数更新有两种方式. 第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度.这种方法每更新一次参数都要把数据集里的所有样本都看一遍, ...

  7. 转载: scikit-learn学习之K-means聚类算法与 Mini Batch K-Means算法

    版权声明:<—— 本文为作者呕心沥血打造,若要转载,请注明出处@http://blog.csdn.net/gamer_gyt <—— 目录(?)[+] ================== ...

  8. 深度学习中 Batch Normalization

    深度学习中 Batch Normalization为什么效果好?(知乎) https://www.zhihu.com/question/38102762

  9. 一文读懂神经网络训练中的Batch Size,Epoch,Iteration

    一文读懂神经网络训练中的Batch Size,Epoch,Iteration 作为在各种神经网络训练时都无法避免的几个名词,本文将全面解析他们的含义和关系. 1. Batch Size 释义:批大小, ...

随机推荐

  1. 前端HTML学习心得

    学习最好的效果就是理论加实践--Hanks!!!(给大家打鸡血的哈哈哈) 前面的学习我教大家怎么搭建简单的前端开发环境,现在我教大家怎么使用工具学习(从入门到放弃哈哈,不不不,这是以前的我,现在我下了 ...

  2. FDD与TDD的区别

    LTE通常分为FDD LTE和TDD LTEFDD,频分双工(Frequency Division Duplexing)我和你通信,像广播一样,只能我说你听,是单工:像对讲机一样,同一时间只能一方说, ...

  3. c++11 standardized memory model 内存模型

    C++11 标准中引入了内存模型,其目的是为了解决多线程中可见性和顺序(order).这是c++11最重要的新特征,标准忽略了平台的差异,从语义层面规定了6种内存模型来实现跨平台代码的兼容性.多线程代 ...

  4. Spring MVC原理图及其重要组件

    一.Spring MVC原理图: ps: springmvc的运行流程为图中数字序号 二.springmvc的重要组件: 1)前端控制器:DispatchServlet(不需要程序员开发) 接收请求, ...

  5. 13 Spring 的事务控制

    1.事务的概念 理解事务之前,先讲一个你日常生活中最常干的事:取钱.  比如你去ATM机取1000块钱,大体有两个步骤:首先输入密码金额,银行卡扣掉1000元钱:然后ATM出1000元钱.这两个步骤必 ...

  6. vue通过ajax加载json数据

    HTML <ul id="Hanapp"> <li class="styVue" v-for="item in actList&qu ...

  7. 【计算机网络基础】URI、URN和URL的区别

    先引用一张关系图 灰色部分为URI URI强调的是给资源标记命名,URL强调的是给资源定位. URI是Uniform Resource Identifier,表示是一个资源: URL是Uniform ...

  8. hdu 2841 题解

    题目 题意:就是问在一个$ n* m $的矩阵中站在 $ (0,0) $ 能看到几个整数点. 很明显如果有两个平行向量 $ \vec{a}=(x_1,y_1) $ ,$ \vec{b}=(x_2,y_ ...

  9. 图片url地址的生成获取方法

    在写博客插入图片时,许多时候需要提供图片的url地址.作为菜鸡的我,自然是一脸懵逼.那么什么是所谓的url地址呢?又该如何获取图片的url地址呢? 首先来看一下度娘对url地址的解释:url是统一资源 ...

  10. gorm 处理时间戳

    问题 在使用 gorm 的过程中, 处理时间戳字段时遇到问题.写时间戳到数据库时无法写入. 通过查阅资料最终问题得以解决,特此总结 设置数据库的 dsn parseTime = "True& ...