1、知识点

"""
模拟一个y = 0.7x+0.8的案例 报警:
1、initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02
解决方法:由于使用了tf.initialize_all_variables() 初始化变量,该方法已过时,使用tf.global_variables_initializer()就不会了 tensorboard查看数据:
1、收集变量信息
tf.summary.scalar()
tf.summary.histogram()
merge = tf.summary.merge_all()
2、创建事件机制
fileWriter = tf.summary.FileWriter(logdir='',graph=sess.graph)
3、在sess中运行并合并merge
summary = sess.run(merge)
4、在循环训练中将变量添加到事件中
fileWriter.add_summary(summary,i) #i为训练次数 保存并加载训练模型:
1、创建保存模型saver对象
saver = tf.train.Saver()
2、保存模型
saver.save(sess,'./ckpt/model')
3、利用保存的模型加载模型,变量初始值从保存模型读取
if os.path.exists('./ckpt/checkpoint'):
saver.restore(sess,'./ckpt/model') 创建变量域:
with tf.variable_scope("data"):
"""

2、代码

# coding = utf-8

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '' def myLinear():
"""
自实现线性回归
:return:
"""
with tf.variable_scope("data"):
#1、准备数据
x = tf.random_normal((100,1),mean=0.5,stddev=1,name='x')
y_true = tf.matmul(x,[[0.7]])+0.8 #矩阵相乘至少为2维 with tf.variable_scope("model"):
#2、初始化权重和偏置
weight = tf.Variable(tf.random_normal((1,1)),name='w')
bias = tf.Variable(0.0,name='b')
y_predict = tf.matmul(x,weight)+bias with tf.variable_scope("loss"):
#3、计算损失值
loss = tf.reduce_mean(tf.square(y_true-y_predict)) with tf.variable_scope("train"):
#4、梯度下降优化loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss) #初始化变量
init_op = tf.global_variables_initializer() ############收集变量信息存到tensorboard查看###############
#收集变量
tf.summary.scalar('losses',loss)#1维
tf.summary.histogram('weight',weight) #高维
tf.summary.histogram('bias', bias) # 高维
merged = tf.summary.merge_all() #将变量合并
######################################################### #####################保存并加载模型###############
saver = tf.train.Saver()
#################################################
#5、循环训练
with tf.Session() as sess:
sess.run(init_op) #运行是初始化变量
if os.path.exists('./ckpt/checkpoint'):
saver.restore(sess,'./ckpt/model') #建立事件机制
fileWriter = tf.summary.FileWriter(logdir='./tmp',graph=sess.graph)
print("初始化权重为:%f,偏置为:%f" %(weight.eval(),bias.eval()))
for i in range(501):
summary = sess.run(merged) # 运行并合并
fileWriter.add_summary(summary,i)
sess.run(train_op)
if i%10==0 :
print("第%d次训练权重为:%f,偏置为:%f" % (i,weight.eval(), bias.eval()))
saver.save(sess,'./ckpt/model')
return None if __name__ == '__main__':
myLinear()

 3、代码

import tensorflow as tf
import csv
import numpy as np
import matplotlib.pyplot as plt
# 设置学习率
learning_rate = 0.01
# 设置训练次数
train_steps = 1000
with open('D:/Machine Learning/Data_wrangling/鲍鱼数据集.csv') as file:
reader = csv.reader(file)
a, b = [], []
for item in reader:
b.append(item[8])
del(item[8])
a.append(item)
file.close()
x_data = np.array(a)
y_data = np.array(b)
for i in range(len(x_data)):
y_data[i] = float(y_data[i])
for j in range(len(x_data[i])):
x_data[i][j] = float(x_data[i][j])
# 定义各影响因子的权重
weights = tf.Variable(np.ones([8,1]),dtype = tf.float32)
x_data_ = tf.placeholder(tf.float32, [None, 8])
y_data_ = tf.placeholder(tf.float32, [None, 1])
bias = tf.Variable(1.0, dtype = tf.float32)#定义偏差值
# 构建模型为:y_model = w1X1 + w2X2 + w3X3 + w4X4 + w5X5 + w6X6 + w7X7 + w8X8 + bias
y_model = tf.add(tf.matmul(x_data_ , weights), bias)
# 定义损失函数
loss = tf.reduce_mean(tf.pow((y_model - y_data_), 2))
#训练目标为损失值最小,学习率为0.01
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("Start training!")
lo = []
sample = np.arange(train_steps)
for i in range(train_steps):
for (x,y) in zip(x_data, y_data):
z1 = x.reshape(1,8)
z2 = y.reshape(1,1)
sess.run(train_op, feed_dict = {x_data_ : z1, y_data_ : z2})
l = sess.run(loss, feed_dict = {x_data_ : z1, y_data_ : z2})
lo.append(l)
print(weights.eval(sess))
print(bias.eval(sess))
# 绘制训练损失变化图
plt.plot(sample, lo, marker="*", linewidth=1, linestyle="--", color="red")
plt.title("The variation of the loss")
plt.xlabel("Sampling Point")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

tensorflow实现线性回归总结的更多相关文章

  1. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  2. TensorFlow简单线性回归

    TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...

  3. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  4. 利用TensorFlow实现线性回归模型

    准备数据: import numpy as np import tensorflow as tf import matplotlib.pylot as plt # 随机生成1000个点,围绕在y=0. ...

  5. 如何用TensorFlow实现线性回归

    环境Anaconda 废话不多说,关键看代码 import tensorflow as tf import os os.environ['TF_CPP_MIN_LOG_LEVEL']='2' tf.a ...

  6. TensorFlow多元线性回归实现

    多元线性回归的具体实现 导入需要的所有软件包:   因为各特征的数据范围不同,需要归一化特征数据.为此定义一个归一化函数.另外,这里添加一个额外的固定输入值将权重和偏置结合起来.为此定义函数 appe ...

  7. TensorFlow实现线性回归模型代码

    模型构建 1.示例代码linear_regression_model.py #!/usr/bin/python # -*- coding: utf-8 -* import tensorflow as ...

  8. 学习TensorFlow,线性回归模型

    学习TensorFlow,在MNIST数据集上建立softmax回归模型并测试 一.代码 <span style="font-size:18px;">from tens ...

  9. tensorflow 学习1——tensorflow 做线性回归

    . 首先 Numpy: Numpy是Python的科学计算库,提供矩阵运算. 想想list已经提供了矩阵的形式,为啥要用Numpy,因为numpy提供了更多的函数. 使用numpy,首先要导入nump ...

随机推荐

  1. Java面向对象(一)

    面向对象(Object Oriented) 面向过程:事物比较简单.将问题分解为若干个步骤.按照步骤依次执行.面向对象:事物比较复杂.在解决面向对象的过程中,最后的执行部分还是面向过程方式,面向过程和 ...

  2. 微信小程序音乐播放器

    写在前面 1.入门几天小白的作品,希望为您有帮助,有好的意见或简易烦请赐教 2.微信小程序审核音乐类别已经下架,想要发布选题需慎重.附一个参考链接,感谢https://www.hishop.com.c ...

  3. Hadoop_11_HDFS的流式 API 操作

    对于MapReduce等框架来说,需要有一套更底层的API来获取某个指定文件中的一部分数据,而不是一整个文件 因此使用流的方式来操作 HDFS上的文件,可以实现读取指定偏移量范围的数据 1.客户端测试 ...

  4. CSS相对定位与绝对定位详解

    相对定位和绝对定位,不改变元素的大小形状,只改变元素的位置. 相对定位和绝对定位是通过position属性来控制的,position属性的值为下面几种: 值 描述 absolute 使元素绝对定位,相 ...

  5. angular 中同级元素交替样式

    事件 : ng-click="addNews()"  所属div的层级:    div > div  >span 即,对于 event.target 查找的话最多 从s ...

  6. 银行卡号Luhn校验算法

    /** *银行卡号Luhn校验算法 *luhn校验规则:16位银行卡号(19位通用): *1.将未带校验位的 15(或18)位卡号从右依次编号 1 到 15(18),位于奇数位号上的数字乘以 2. * ...

  7. Java-收邮件

    import java.util.Properties; import javax.mail.Folder; import javax.mail.Message; import javax.mail. ...

  8. redis配置主从备份以及主备切换方案配置(转)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/gsying1474/article/de ...

  9. 关于C# Dockpanel的一些入门的基本操作

    关于C# Dockpanel的一些入门的基本操作 原文链接:https://blog.csdn.net/Lc1996Jm/article/details/51881064 一.引用: 1.建立一个Wi ...

  10. 使用fiddler抓取jmeter发送的请求

    使用jmeter发送请求时,有时需要查看发送的请求是否合理,可以使用fiddler更直观的抓取并查看jmeter发送的请求.步骤如下:1.设置fidder-connections 端口号为8888 2 ...