import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np train_step = 5
train_path = 'train.csv'
is_train = False
learn_rate = 0.0001
epochs = 10 data = pd.read_csv(train_path) # 取部分特征字段用于分类,并将所有缺失的字段填充为0
data['Sex'] = data['Sex'].apply(lambda s: 1 if s == 'male' else 0)
data = data.fillna(0)
dataset_X = data[['Sex', 'Age', 'Pclass', 'SibSp', 'Parch', 'Fare']]
dataset_X = dataset_X.as_matrix() # 两种分类分别是幸存和死亡,'Survived'字段是其中一种分类的标签
# 新增'Deceased'字段表示第二种分类的标签,取值为'Survived'字段取非
data['Deceased'] = data['Survived'].apply(lambda s: int(not s))
dataset_Y = data[['Deceased', 'Survived']]
dataset_Y = dataset_Y.as_matrix() # 使用sklearn的train_test_split函数将标记数据切分为‘训练数据集和验证数据集’
# 将全部标记数据随机洗牌后切分,其中验证数据占20%,由test_size参数指定
X_train, X_test, Y_train, Y_test = train_test_split(dataset_X, dataset_Y,
test_size=0.2, random_state=42)
# 声明输入数据点位符
X = tf.placeholder(tf.float32, shape=[None, 6])
Y = tf.placeholder(tf.float32, shape=[None, 2])
# 声明变量(参数)
W = tf.Variable(tf.random_normal([6, 2]), name='weights')
b = tf.Variable(tf.zeros([2]), name='bias')
# 构造前向传播计算图
y_pred = tf.nn.softmax(tf.matmul(X, W) + b) # 使用交叉熵作为代价函数 Y * log(y_pred + e-10),程序中e-10,防止y_pred十分接近0或者1时,
# 计算(log0)会得到无穷,导致非法,进一步导致无法计算梯度,迭代陷入崩溃。
cross_entropy = -tf.reduce_sum(Y * tf.log(y_pred + 1e-10), reduction_indices=1)
# 批量样本的代价为所有样本交叉熵的平均值
cost = tf.reduce_mean(cross_entropy)
# 使用随机梯度下降算法优化器来最小化代价,系统自动构建反向传播部分的计算图
train_op = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost) saver = tf.train.Saver()
if is_train:
with tf.Session() as sess:
writer = tf.summary.FileWriter('logfile', sess.graph)
# 初始化所有变量,必须最先执行
tf.global_variables_initializer().run()
# 以下为训练迭代,迭代10轮
for epoch in range(10):
total_loss = 0
for i in range(len(X_train)):
_, loss = sess.run([train_op, cost], feed_dict={X:[X_train[i]], Y:[Y_train[i]]})
total_loss += loss
print('Epoch: %04d, total loss=%.9f' % (epoch + 1, total_loss))
# 保存model
if (epoch + 1) % train_step == 0:
save_path = saver.save(sess, './model/model.ckpt', global_step=epoch + 1)
print('Training complete!')
pred = sess.run(y_pred, feed_dict={X: X_test})
# np.argmax的axis=1表示第2轴最大值的索引(这里表示列与列对比,最大值的索引)
correct = np.equal(np.argmax(pred, axis=1), np.argmax(Y_test, axis=1))
accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set: %.9f" % accuracy)
else:
# 恢复model,继续训练
with tf.Session() as sess1:
# 从'checkpoint'文件中读出最新存档的路径
ckpt = tf.train.get_checkpoint_state('./model')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess1, ckpt.model_checkpoint_path)
print('restore model sucess!')
else:
sys(0)
print('continue train …………')
for epoch in range(epochs):
total_loss = 0
for i in range(len(X_train)):
_, loss = sess1.run([train_op, cost], feed_dict={X:[X_train[i]], Y:[Y_train[i]]})
total_loss += loss
print('Epoch: %04d, total loss=%.9f' % (epoch + 1, total_loss))
# 保存model
if (epoch + 1) % train_step == 0:
save_path = saver.save(sess1, './model/model.ckpt', global_step=epoch + 1)
print('Training complete!')
pred = sess1.run(y_pred, feed_dict={X: X_test})
# np.argmax的axis=1表示第2轴最大值的索引(这里表示列与列对比,最大值的索引)
correct = np.equal(np.argmax(pred, axis=1), np.argmax(Y_test, axis=1))
accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set: %.9f" % accuracy) # 恢复model参数
with tf.Session() as sess2:
# 从'checkpoint'文件中读出最新存档的路径
print('restore lastest model, compute Accuracy!')
ckpt = tf.train.get_checkpoint_state('./model')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess2, ckpt.model_checkpoint_path)
pred = sess2.run(y_pred, feed_dict={X: X_test})
# np.argmax的axis=1表示第2轴最大值的索引(这里表示列与列对比,最大值的索引)
correct = np.equal(np.argmax(pred, axis=1), np.argmax(Y_test, axis=1))
accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set: %.9f" % accuracy)

TensorFlow自带的可视化工具TensorBoard

在当前目录的命令行下键入:tensorboard --logdir=logfile

根据命令行的提示,在浏览器里输入相应的网址。

TensorFlow入门-Tianic数据集训练的更多相关文章

  1. 搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型

    原文地址:搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型 0x00 环境 OS: Ubuntu 1810 x64 Anaconda: 4.6.12 P ...

  2. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  3. FaceRank,最有趣的 TensorFlow 入门实战项目

    FaceRank,最有趣的 TensorFlow 入门实战项目 TensorFlow 从观望到入门! https://github.com/fendouai/FaceRank 最有趣? 机器学习是不是 ...

  4. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

  5. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  6. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  7. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  8. 分享《机器学习实战基于Scikit-Learn和TensorFlow》中英文PDF源代码+《深度学习之TensorFlow入门原理与进阶实战》PDF+源代码

    下载:https://pan.baidu.com/s/1qKaDd9PSUUGbBQNB3tkDzw <机器学习实战:基于Scikit-Learn和TensorFlow>高清中文版PDF+ ...

  9. TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...

随机推荐

  1. 罗技 M558 鼠标维修记录

    罗技 M558 鼠标维修记录 故障现象 按键不灵敏 拆机内部图 前进键 后退键 左键 右键 中键 自定义功能键 使用的是 OMRON 按键,好东西,质量可以. 但毕竟是机械的东西,还是有老化,用万用表 ...

  2. python findall() re.S

    官方文档:https://docs.python.org/3.6/library/re.html 教程:http://www.regexlab.com/zh/regref.htm re.findall ...

  3. css 填坑常用代码分享[居家实用型]

    原文地址 http://www.cnblogs.com/jikey/p/4233003.html 以下是常用的代码收集,没有任何技术含量,只是填坑的积累.转载请注明出处,谢谢. 一. css 2.x ...

  4. autoit 中文API

    中文API 参考地址: http://www.jb51.net/shouce/autoit/ 虫师的selelnium里面也有简单的说 环境搭建+上传弹窗的小案例

  5. Spring 3.1新特性之一:spring注解之@profile

    前言 由于在项目中使用Maven打包部署的时候,经常由于配置参数过多(比如Nginx服务器的信息.ZooKeeper的信息.数据库连接.Redis服务器地址等),导致实际现网的配置参数与测试服务器参数 ...

  6. 学习git最好的方式

    1:登陆git官网网站 https://git-scm.com 2:点击esay to learn连接 3:点击Book连接 4:选择简体中文,下载PDF文档,也可以在线学习.

  7. Centos 6.5将光盘作为yum源的设置方法

    Centos 6.5将光盘作为yum源的设置方法 在使用Centos 的时候,用yum来安装软件包是再方便不过了,但是如果在无法连接互联网的情况下,yum就不好用了. 下面介绍一种方式,就是将Cent ...

  8. JedisCluster模式尝试进行批量操作

    搭建完redis集群后,可以通过jedis的JedisCluster来访问Redis集群,这里列出使用jedisCluster的spring bean配置方式:   <bean id=" ...

  9. 用PNG作为Texture创建Material

    转自:http://aigo.iteye.com/blog/2279512 1,导入一张png素材作为Texture 2,新建一个Material,设置Blend Mode为Translucent,连 ...

  10. Logistic回归的两种形式y=0/1,y=+1/-1

    第一种形式:y=0/1 第二种形式:y=+1/-1 第一种形式的损失函数可由极大似然估计推出: 第二种形式的损失函数:  , 参考:https://en.wikipedia.org/wiki/Loss ...