深度学习原理与框架-Tensorflow基本操作-实现线性拟合
代码:使用tensorflow进行数据点的线性拟合操作
第一步:使用np.random.normal生成正态分布的数据
第二步:将数据分为X_data 和 y_data
第三步:对参数W和b, 使用tf.Variable()进行初始化,对于参数W,使用tf.random_normal([1], -1.0, 1.0)构造初始值,对于参数b,使用tf.zeros([1]) 构造初始值
第四步:使用W * X_data + b 构造出预测值y_pred
第五步:使用均分误差来表示loss损失值,即tf.reduce_mean(tf.square(y_data - y_pred))
第六步:使用opt = tf.train.GradientDescentOptimizer(0.5).minimize(loss) 梯度下架来降低损失值
第七步:循环,使用sess.run(opt) 执行梯度降低损失值的操作,并打印w,b和loss
第八步:进行作图操作,画出散点图和拟合好的曲线图
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf # 第一步:使用np.random.normal创建数据,即y = 0.1*x + 0.3
data = []
num_data = 1000
for i in range(num_data):
x_data = np.random.normal(0.0, 0.55)
y_data = 0.1 * x_data + 0.3 + np.random.normal(0.0, 0.03)
data.append([x_data, y_data]) # 第二步:将数据进行分配,分成特征和标签
X_data = [v[0] for v in data]
y_data = [v[1] for v in data] # 第三步:使用tf.Variable进行参数的初始化操作
W = tf.Variable(tf.random_normal([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]))
# 第四步:使用X_data * W + b 计算损失值
y_pred = X_data * W + b
# 第五步:使用均分误差来作为损失值
loss = tf.reduce_mean(tf.square(y_data - y_pred))
# 第六步:使用梯度下降来降低损失值
opt = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)
# 参数初始化操作
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
for i in range(20):
# 第七步:循环,执行梯度下降操作,打印w,b和loss
sess.run(opt)
print('W:%g b:%g loss:%g'%(sess.run(W), sess.run(b), sess.run(loss))) # 第八步: 画图操作
plt.scatter(X_data, y_data, c='r')
plt.plot(X_data, sess.run(W) * X_data + sess.run(b))
plt.show()
深度学习原理与框架-Tensorflow基本操作-实现线性拟合的更多相关文章
- 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)
1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...
- 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位
1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...
- 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量
1.tf.Variable([[1, 2]]) # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...
- 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)
1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...
- 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)
1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...
- 深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类
使用tensorflow构造神经网络用来进行mnist数据集的分类 相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数, 神经网络使用的损失值为 ...
- 深度学习原理与框架-图像补全(原理与代码) 1.tf.nn.moments(求平均值和标准差) 2.tf.control_dependencies(先执行内部操作) 3.tf.cond(判别执行前或后函数) 4.tf.nn.atrous_conv2d 5.tf.nn.conv2d_transpose(反卷积) 7.tf.train.get_checkpoint_state(判断sess是否存在
1. tf.nn.moments(x, axes=[0, 1, 2]) # 对前三个维度求平均值和标准差,结果为最后一个维度,即对每个feature_map求平均值和标准差 参数说明:x为输入的fe ...
- 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)
1. sys.argv[1:] # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...
- 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)
1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...
随机推荐
- Windowsx64位安装pymssql并完成与数据库链接
常流程只需要打开下载并按照常规方法安装mssql包即可在程序中import pymssql,不过安装mssql确实有些小麻烦. 从开始安装就开始出现了各种异常错误 首先出现sqlfront.h文件找不 ...
- [转]jvm加载类规则
jvm包括三种类加载器: 第一种:bootstrap classloader:加载Java的核心类. 第二种:extension classloader:负责加载jre的扩展目录中的jar包. 第三种 ...
- python接口自动化20-requests获取响应时间(elapsed)与超时(timeout) ok试了 获取响应时间的
前言 requests发请求时,接口的响应时间,也是我们需要关注的一个点,如果响应时间太长,也是不合理的.如果服务端没及时响应,也不能一直等着,可以设置一个timeout超时的时间 关于request ...
- whith ~ as 用法
个人理解 with self.client.get("/", catch_response=True) as response: 其实就是 response = self.clie ...
- DP 01背包 七夕模拟赛
问题 D: 七夕模拟赛 时间限制: 1 Sec 内存限制: 128 MB提交: 60 解决: 23[提交][状态][讨论版] 题目描述 " 找啊找啊找GF,找到一个好GF,吃顿饭啊拉拉手 ...
- Java-Runoob-高级教程-实例-时间处理:03. Java 实例 - 获取年份、月份等
ylbtech-Java-Runoob-高级教程-实例-时间处理:03. Java 实例 - 获取年份.月份等 1.返回顶部 1. Java 实例 - 获取年份.月份等 Java 实例 以下实例演示 ...
- 【ZZ】号称“开发者神器”的GitHub,到底该怎么用?
号称“开发者神器”的GitHub,到底该怎么用? https://mp.weixin.qq.com/s/zpKOBMKWckY05Mv_B28RgQ A developer’s introductio ...
- [转][C#]拆分参数对
本文来自:https://www.jb51.net/article/62932.htm /// <summary> /// 分析 url 字符串中的参数信息 /// </summar ...
- MySQL mysqlbinlog企业案例
内容待补充 案例文字说明: 7.3 故障时间点: 周四上午10点,开发人员误删除了一个表,如何恢复? 7.4 思路: 1.停业务,避免数据的二次伤害 2.找一个临时库,恢复周三23:00全备 3.截取 ...
- 利用队列Queue实现一个多并发“线程池”效果的Socket程序
本例通过利用类Queue建立了一个存放着Thread对象的“容器对象”,当Client端申请与Server端通信时,在Server端的“链接循环”中每次拿出一个Thread对象去创建“线程链接”,从而 ...