Breast Cancer on PyTorch

Code

# encoding:utf8

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
import numpy as np class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = nn.Linear(30, 60)
self.a1 = nn.Sigmoid()
self.l2 = nn.Linear(60, 2)
self.a2 = nn.ReLU()
self.l3 = nn.Softmax(dim=1) def forward(self, x):
x = self.l1(x)
x = self.a1(x)
x = self.l2(x)
x = self.a2(x)
x = self.l3(x)
return x if __name__ == '__main__':
breast_cancer = load_breast_cancer() x_train, x_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.25)
x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long) net = Net() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.005) # PyTorch suit to tiny learning rate error = list() for epoch in range(250):
optimizer.zero_grad()
y_pred = net(x_train)
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()
error.append(loss.item()) y_pred = net(x_test)
y_pred = torch.argmax(y_pred, dim=1) # it is necessary that drawing the loss plot when we fine tuning the model
plt.plot(np.arange(1, len(error)+1), error)
plt.show() print(classification_report(y_test, y_pred, target_names=breast_cancer.target_names))

损失函数图像:

nn.Sequential

# encoding:utf8

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
import numpy as np if __name__ == '__main__':
breast_cancer = load_breast_cancer() x_train, x_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.25)
x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long) net = nn.Sequential(
nn.Linear(30, 60),
nn.Sigmoid(),
nn.Linear(60, 2),
nn.ReLU(),
nn.Softmax(dim=1)
) criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.005) # PyTorch suit to tiny learning rate error = list() for epoch in range(250):
optimizer.zero_grad()
y_pred = net(x_train)
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()
error.append(loss.item()) y_pred = net(x_test)
y_pred = torch.argmax(y_pred, dim=1) # it is necessary that drawing the loss plot when we fine tuning the model
plt.plot(np.arange(1, len(error)+1), error)
plt.show() print(classification_report(y_test, y_pred, target_names=breast_cancer.target_names))

模型性能:

              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        14
versicolor 1.00 1.00 1.00 16
virginica 1.00 1.00 1.00 20 accuracy 1.00 50
macro avg 1.00 1.00 1.00 50
weighted avg 1.00 1.00 1.00 50

Iris Classification on PyTorch的更多相关文章

  1. Iris Classification on Tensorflow

    Iris Classification on Tensorflow Neural Network formula derivation \[ \begin{align} a & = x \cd ...

  2. Iris Classification on Keras

    Iris Classification on Keras Installation Python3 版本为 3.6.4 : : Anaconda conda install tensorflow==1 ...

  3. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  4. Pytorch collate_fn用法

    By default, Dataloader use collate_fn method to pack a series of images and target as tensors (first ...

  5. pytorch和tensorflow的爱恨情仇之定义可训练的参数

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch版本:1.6.0 tensorflow版本:1.15.0 之前我们就已 ...

  6. pytorch下对简单的数据进行分类(classification)

    看了Movan大佬的文字教程让我对pytorch的基本使用有了一定的了解,下面简单介绍一下二分类用pytorch的基本实现! 希望详细的注释能够对像我一样刚入门的新手来说有点帮助! import to ...

  7. pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》

    论文  < Convolutional Neural Networks for Sentence Classification>通过CNN实现了文本分类. 论文地址: 666666 模型图 ...

  8. pytorch之 classification

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  9. pytorch 5 classification 分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

随机推荐

  1. JQuery 获取页面某一元素在屏幕上的位置

    获取页面某一元素的绝对X,Y坐标 var X = $('#ElementID').offset().top;//元素在当前视窗距离顶部的位置 var Y = $('#ElementID').offse ...

  2. centos7 cpanm安装,及perl模块安装

    1. cpan安装 yum安装 yum install perl-App-cpanminus.noarch 注意:安装完成后,root及非root用户都可以使用cpanm安装模块,root用户直接用c ...

  3. 关于如何利用计算属性进行button的控制

    element分页没用它的 (这个只要上一页下一页),比如共2页的时候,你在第一页,你肯定可以点击下一页,当你进入到第二页的时候这个button肯定就不能点击了啊,它的属性diaabled=true让 ...

  4. node.js初识09

    1.node_module文件夹 如果你的require中没有写./,那么Node.js将该文件视为node_modules目录下的一个文件. 2.package.json文件 如果使用文件夹来统筹管 ...

  5. javaweb之验证码验证技术

    今天学习了一个验证码校验技术,所以就写下了一些笔记,方便日后查看.首先创建web工程 1.然后在src目录下创建一个Servlet类,此类用来显示登录页面和错误信息提示 package com.LHB ...

  6. tcl脚本

    tcl,全名tool command language,是一种通用的工具语言. 1)每个命令之间,通过换行符或者分号隔开: 2)tcl的每个命令包含一个或者多个单词,默认第一个单词表示命令,第二个单词 ...

  7. Message对象

    一)描述 1: 每一个Message对象都包含两个对象: (1)google::protobuf::Descriptor 描述对象,是Message所有Filed的一个集合,它又包含了FieldDes ...

  8. GCD (RMQ + 二分)

    RMQ存的是区间GCD,然后遍历 i: 1->n, 然后不断地对[i, R]区间进行二分求以i为起点的相同gcd的区间范围,慢慢缩减区间. #include<bits/stdc++.h&g ...

  9. uvalive 3887 Slim Span

    题意: 一棵生成树的苗条度被定义为最长边与最小边的差. 给出一个图,求其中生成树的最小苗条度. 思路: 最开始想用二分,始终想不到二分终止的条件,所以尝试暴力枚举最小边的长度,然后就AC了. 粗略估计 ...

  10. 【CDH学习之三】CDH安装

    登录CM 1.版本选择 免费版本的CM5已经去除50个节点数量的限制. 各个Agent节点正常启动后,可以在当前管理的主机列表中看到对应的节点. 选择要安装的节点,点继续. 接下来,出现以下包名,说明 ...