线性回归问题

 # encoding: utf-8

 import numpy as np
import matplotlib.pyplot as plt data = []
for i in range(100):
x = np.random.uniform(-10., 10.) #均匀分布产生x
eps = np.random.normal(0., 0.01) #高斯分布产生一个误差值
y = 1.477*x + 0.089 +eps #计算得到y值
data.append([x, y]) #保存到data中 data = np.array(data) #转成数组的方式方便处理
plt.plot(data[:,0], data[:,1], 'b') #自己引入用于观察原始数据
#plt.show()
plt.savefig('original data.png') def mse(b, w, points): #计算所有点的预测值和真实值之间的均方误差
totalError = 0
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
totalError += (y -(w*x + b))**2 #真实值减预测值的平方
return totalError/float(len(points)) #返回平均误差值 def step_gradient(b_current, w_current, points, lr): #预测模型中梯度下降方式优化b和w
b_gradient = 0
w_gradient = 0
M = float(len(points))
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
b_gradient += (2/M) * ((w_current*x + b_current) - y) #求偏导数的公式可知
w_gradient += (2/M)*x*((w_current*x + b_current) - y) #求偏导数的公式可知
new_b = b_current - (lr*b_gradient) #更新参数,使用了梯度下降法
new_w = w_current - (lr*w_gradient) #更新参数,使用了梯度下降法
return [new_b, new_w] def gradient_descent(points, starting_b, starting_w, lr, num_iterations): #循环更新w,b多次
b = starting_b
w = starting_w
loss_data = []
for step in range(num_iterations): #计算并更新一次
b, w = step_gradient(b, w, np.array(points), lr) #更新了这一次的b,w
loss = mse(b, w, points)
loss_data.append([step+1, loss])
if step % 50 == 0: #每50次输出一回
print(f"iteration:{step}, loss{loss}, w:{w}, b:{b}")
return [b, w, loss_data] def main():
lr = 0.01 #学习率,梯度下降算法中的参数
initial_b = 0 #初值
initial_w = 0
num_iterations = 1000 #学习100轮
[b, w, loss_data] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)
loss = mse(b, w, data)
print(f'Final loss:{loss}, w:{w}, b:{b}') plt.figure() #观察loss每一步情况
loss_data = np.array(loss_data)
plt.plot(loss_data[:,0], loss_data[:,1], 'g')
plt.savefig('loss.png')
#plt.show() plt.figure() #观察最终的拟合效果
y_fin = w*data[:,0] + b + eps
plt.plot(data[:,0], y_fin, 'r')
#plt.show()
plt.savefig('final data.png') if __name__ == '__main__':
main()

original data (y = w*x + b +eps)

loss rate

final data (y' = w' *x + b' + eps )

最终loss趋近9.17*10^-5, w趋近1.4768, b趋近0.0900

真实的w值1.477, b为0.089

对于线性回归问题,适用性挺好!

主要的数学代码能理解,唯有取梯度的反方向更新参数,不是很能理解!

这里还没有用到tensorflow,下一次更新基础知识!

tensorflow2.0 学习(二)的更多相关文章

  1. tensorflow2.0 学习(三)

    用tensorflow2.0 版回顾了一下mnist的学习 代码如下,感觉这个版本下的mnist学习更简洁,更方便 关于tensorflow的基础知识,这里就不更新了,用到什么就到网上取搜索相关的知识 ...

  2. tensorflow2.0 学习(一)

    虽说是按<TensorFlow深度学习>这本书来学习的,但是总会碰到新的问题!记录下这些问题,有利于巩固知新. 之前学过一些tensorflow1.0的知识,到RNN这章节,后面没有再继续 ...

  3. Tensorflow2.0学习(一)

    站长资讯平台:今天学习一下Tensorflow2.0 的基础 核心库,@tf.function ,可以方便的将动态图的语言,变成静态图,在某种程度上进行计算加速 TensorFlow Lite Ten ...

  4. tensorflow2.0学习笔记

    今天我们开始学习tensorflow2.0,用一种简单和循循渐进的方式,带领大家亲身体验深度学习.学习的目录如下图所示: 1.简单的神经网络学习过程 1.1张量生成 1.2常用函数 1.3鸢尾花数据读 ...

  5. TensorFlow2.0(二):数学运算

    1 基本运算:加(+).减(-).点乘(*).除(/).地板除法(//).取余(%) 基本运算中所有实例都以下面的张量a.b为例进行: >>> a = tf.random.unifo ...

  6. tensorflow2.0学习笔记第一章第二节

    1.2常用函数 本节目标:掌握在建立和操作神经网络过程中常用的函数 # 常用函数 import tensorflow as tf import numpy as np # 强制Tensor的数据类型转 ...

  7. tensorflow2.0学习笔记第二章第四节

    2.4损失函数损失函数(loss):预测值(y)与已知答案(y_)的差距 nn优化目标:loss最小->-mse -自定义 -ce(cross entropy)均方误差mse:MSE(y_,y) ...

  8. tensorflow2.0学习笔记第二章第一节

    2.1预备知识 # 条件判断tf.where(条件语句,真返回A,假返回B) import tensorflow as tf a = tf.constant([1,2,3,1,1]) b = tf.c ...

  9. tensorflow2.0学习笔记第一章第一节

    一.简单的神经网络实现过程 1.1张量的生成 # 创建一个张量 #tf.constant(张量内容,dtpye=数据类型(可选)) import tensorflow as tf import num ...

随机推荐

  1. MNIST机器学习入门(二)

    在前一个博客中,我们已经对MNIST 数据集和TensorFlow 中MNIST 数据集的载入有了基本的了解.本节将真正以TensorFlow 为工具,写一个手写体数字识别程序,使用的机器学习方法是S ...

  2. Drool7s 什么叫KIE和生命周期-系列03课

    KIE是缩写,knowledge is everything.可以理解成一个上层接口,本质是由很多个实现类去实现功能的. 另外关于drool7s的生命周期,请看下图 本文只是让你了解drools7的一 ...

  3. 【Spring-AOP-学习笔记】

    http://outofmemory.cn/java/spring/spring-DI-with-annotation-context-component-scan https://www.cnblo ...

  4. C#录制屏幕采集系统桌面画面

    在项目中,有很多需要录制屏幕的场景,比如直播课,录制教学视频等场景.但.NET自带的Screen类功能比较弱,效率很低.那么如何简单快捷地高效采集桌面屏幕呢?当然是采用SharpCapture!下面开 ...

  5. Nginx fastcgi_cache权威指南

    一.简介 Nginx版本从0.7.48开始,支持了类似Squid的缓存功能.这个缓存是把URL及相关组合当做Key,用Md5算法对Key进行哈希,得到硬盘上对应的哈希目录路径,从而将缓存内容保存在该目 ...

  6. python class 中__next__用法

    class A(): def __init__(self,b): self.b=b # def __iter__(self): # 这个函数可以用,表示迭代标志,但也可以省略 # return sel ...

  7. 缓冲区溢出漏洞 ms04011

    DSScan使用 扫描目标主机是否存在ms04011漏洞 getos使用 获取操作系统类型 > getos.exe 192.168.1.101 ------------------------- ...

  8. 雪妖现世:给SAP Fiori Launchpad增添雪花纷飞的效果

    1995年7月,台湾大宇公司发布了一款国产单机角色扮演游戏神作:<仙剑奇侠传1>,所谓"一包烟,一杯茶",就能在电脑面前坐一整天. 这么经典的游戏Jerry当然已经通关 ...

  9. 【书评:Oracle查询优化改写】第四章

    [书评:Oracle查询优化改写]第四章 BLOG文档结构图 一.1 导读 各位技术爱好者,看完本文后,你可以掌握如下的技能,也可以学到一些其它你所不知道的知识,~O(∩_∩)O~: ① check的 ...

  10. idea 启动项目报错,more than one fragment with the name [spring web] was found

    这是由于idea导入项目的时候有多个模块,并且有多个web.xml导致的,先删除对应的模块,后启动即可. 另外也有可能是spring的jar冲突,把冲突的jar删除即可.