import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np train_step = 5
train_path = 'train.csv'
is_train = False
learn_rate = 0.0001
epochs = 10 data = pd.read_csv(train_path) # 取部分特征字段用于分类,并将所有缺失的字段填充为0
data['Sex'] = data['Sex'].apply(lambda s: 1 if s == 'male' else 0)
data = data.fillna(0)
dataset_X = data[['Sex', 'Age', 'Pclass', 'SibSp', 'Parch', 'Fare']]
dataset_X = dataset_X.as_matrix() # 两种分类分别是幸存和死亡,'Survived'字段是其中一种分类的标签
# 新增'Deceased'字段表示第二种分类的标签,取值为'Survived'字段取非
data['Deceased'] = data['Survived'].apply(lambda s: int(not s))
dataset_Y = data[['Deceased', 'Survived']]
dataset_Y = dataset_Y.as_matrix() # 使用sklearn的train_test_split函数将标记数据切分为‘训练数据集和验证数据集’
# 将全部标记数据随机洗牌后切分,其中验证数据占20%,由test_size参数指定
X_train, X_test, Y_train, Y_test = train_test_split(dataset_X, dataset_Y,
test_size=0.2, random_state=42)
# 声明输入数据点位符
X = tf.placeholder(tf.float32, shape=[None, 6])
Y = tf.placeholder(tf.float32, shape=[None, 2])
# 声明变量(参数)
W = tf.Variable(tf.random_normal([6, 2]), name='weights')
b = tf.Variable(tf.zeros([2]), name='bias')
# 构造前向传播计算图
y_pred = tf.nn.softmax(tf.matmul(X, W) + b) # 使用交叉熵作为代价函数 Y * log(y_pred + e-10),程序中e-10,防止y_pred十分接近0或者1时,
# 计算(log0)会得到无穷,导致非法,进一步导致无法计算梯度,迭代陷入崩溃。
cross_entropy = -tf.reduce_sum(Y * tf.log(y_pred + 1e-10), reduction_indices=1)
# 批量样本的代价为所有样本交叉熵的平均值
cost = tf.reduce_mean(cross_entropy)
# 使用随机梯度下降算法优化器来最小化代价,系统自动构建反向传播部分的计算图
train_op = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost) saver = tf.train.Saver()
if is_train:
with tf.Session() as sess:
writer = tf.summary.FileWriter('logfile', sess.graph)
# 初始化所有变量,必须最先执行
tf.global_variables_initializer().run()
# 以下为训练迭代,迭代10轮
for epoch in range(10):
total_loss = 0
for i in range(len(X_train)):
_, loss = sess.run([train_op, cost], feed_dict={X:[X_train[i]], Y:[Y_train[i]]})
total_loss += loss
print('Epoch: %04d, total loss=%.9f' % (epoch + 1, total_loss))
# 保存model
if (epoch + 1) % train_step == 0:
save_path = saver.save(sess, './model/model.ckpt', global_step=epoch + 1)
print('Training complete!')
pred = sess.run(y_pred, feed_dict={X: X_test})
# np.argmax的axis=1表示第2轴最大值的索引(这里表示列与列对比,最大值的索引)
correct = np.equal(np.argmax(pred, axis=1), np.argmax(Y_test, axis=1))
accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set: %.9f" % accuracy)
else:
# 恢复model,继续训练
with tf.Session() as sess1:
# 从'checkpoint'文件中读出最新存档的路径
ckpt = tf.train.get_checkpoint_state('./model')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess1, ckpt.model_checkpoint_path)
print('restore model sucess!')
else:
sys(0)
print('continue train …………')
for epoch in range(epochs):
total_loss = 0
for i in range(len(X_train)):
_, loss = sess1.run([train_op, cost], feed_dict={X:[X_train[i]], Y:[Y_train[i]]})
total_loss += loss
print('Epoch: %04d, total loss=%.9f' % (epoch + 1, total_loss))
# 保存model
if (epoch + 1) % train_step == 0:
save_path = saver.save(sess1, './model/model.ckpt', global_step=epoch + 1)
print('Training complete!')
pred = sess1.run(y_pred, feed_dict={X: X_test})
# np.argmax的axis=1表示第2轴最大值的索引(这里表示列与列对比,最大值的索引)
correct = np.equal(np.argmax(pred, axis=1), np.argmax(Y_test, axis=1))
accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set: %.9f" % accuracy) # 恢复model参数
with tf.Session() as sess2:
# 从'checkpoint'文件中读出最新存档的路径
print('restore lastest model, compute Accuracy!')
ckpt = tf.train.get_checkpoint_state('./model')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess2, ckpt.model_checkpoint_path)
pred = sess2.run(y_pred, feed_dict={X: X_test})
# np.argmax的axis=1表示第2轴最大值的索引(这里表示列与列对比,最大值的索引)
correct = np.equal(np.argmax(pred, axis=1), np.argmax(Y_test, axis=1))
accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set: %.9f" % accuracy)

TensorFlow自带的可视化工具TensorBoard

在当前目录的命令行下键入:tensorboard --logdir=logfile

根据命令行的提示,在浏览器里输入相应的网址。

TensorFlow入门-Tianic数据集训练的更多相关文章

  1. 搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型

    原文地址:搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型 0x00 环境 OS: Ubuntu 1810 x64 Anaconda: 4.6.12 P ...

  2. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  3. FaceRank,最有趣的 TensorFlow 入门实战项目

    FaceRank,最有趣的 TensorFlow 入门实战项目 TensorFlow 从观望到入门! https://github.com/fendouai/FaceRank 最有趣? 机器学习是不是 ...

  4. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

  5. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  6. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  7. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  8. 分享《机器学习实战基于Scikit-Learn和TensorFlow》中英文PDF源代码+《深度学习之TensorFlow入门原理与进阶实战》PDF+源代码

    下载:https://pan.baidu.com/s/1qKaDd9PSUUGbBQNB3tkDzw <机器学习实战:基于Scikit-Learn和TensorFlow>高清中文版PDF+ ...

  9. TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...

随机推荐

  1. juc并发工具类之CountDownLatch闭锁

    import java.util.concurrent.CountDownLatch; /** * 闭锁: 在进行某些运算时, 只有其他所有线程的运算全部完成,当前运算才继续执行(程序流中加了一道栅栏 ...

  2. 大数阶乘 nyoj

    大数阶乘 时间限制:3000 ms  |  内存限制:65535 KB 难度:3   描述 我们都知道如何计算一个数的阶乘,可是,如果这个数很大呢,我们该如何去计算它并输出它?   输入 输入一个整数 ...

  3. opencv中的更通用的形态学

    为了处理更为复杂的情况,opencv中还支持更多的形态学变换. 形态学名称 操作过程 操作名称 是否需要temp参数 开操作 open open(src)=先腐蚀,后膨胀  CV_MOP_OPEN 否 ...

  4. WCF揭秘学习笔记(5):WF定制活动

    WF(Windows Workflow Foundation,Windows工作流基础)为.NET提供了一种基于模型的.声明方式的过程执行引擎,它改变了传统的通过一行行编写代码来开发服务功能的方式. ...

  5. 启动ECLIPSE时,提示找到不 eclipse\jre\bin\javaw.exe

    原因:在PATH中未配置 jre\bin目录 %JAVA_HOME%\jre\bin - 无论是用:D:\Asoft\Java\jdk1.7.0_45\jre\bin 还是:D:\Asoft\Java ...

  6. vue 之radio绑定v-model

    示例: 单选radio <label ><input type="radio" value="0" v-model="branch& ...

  7. BASIC-10_蓝桥杯_十进制转十六进制

    示例代码: #include <stdio.h>#define N 16 void dg(int a){ int y = a%N; int next = (a-y)/N; if (next ...

  8. BASIC-1_蓝桥杯_闰年判断

    正确代码: #include <stdio.h> int main(void){ int year = 0 ; scanf("%d",&year); if (y ...

  9. 杂项:mPaaS

    ylbtech-杂项:mPaaS 1. 概述返回顶部 mPaaS 是源于支付宝 App 的移动开发平台,为移动开发.测试.运营及运维提供云到端的一站式解决方案,能有效降低技术门槛.减少研发成本.提升开 ...

  10. 关于Oracle与MySQL的使用总结

    平时使用的比较多的数据库管理系统就是Oracle和MySQL,我在这里记录下使用过程中的遇到的问题以及解决方案,以备不时之需 Oracle 关于表空间 Oracle创建数据的代价还是比较大的,所以使用 ...