Breast Cancer on PyTorch

Code

# encoding:utf8

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
import numpy as np class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = nn.Linear(30, 60)
self.a1 = nn.Sigmoid()
self.l2 = nn.Linear(60, 2)
self.a2 = nn.ReLU()
self.l3 = nn.Softmax(dim=1) def forward(self, x):
x = self.l1(x)
x = self.a1(x)
x = self.l2(x)
x = self.a2(x)
x = self.l3(x)
return x if __name__ == '__main__':
breast_cancer = load_breast_cancer() x_train, x_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.25)
x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long) net = Net() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.005) # PyTorch suit to tiny learning rate error = list() for epoch in range(250):
optimizer.zero_grad()
y_pred = net(x_train)
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()
error.append(loss.item()) y_pred = net(x_test)
y_pred = torch.argmax(y_pred, dim=1) # it is necessary that drawing the loss plot when we fine tuning the model
plt.plot(np.arange(1, len(error)+1), error)
plt.show() print(classification_report(y_test, y_pred, target_names=breast_cancer.target_names))

损失函数图像:

nn.Sequential

# encoding:utf8

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
import numpy as np if __name__ == '__main__':
breast_cancer = load_breast_cancer() x_train, x_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.25)
x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long) net = nn.Sequential(
nn.Linear(30, 60),
nn.Sigmoid(),
nn.Linear(60, 2),
nn.ReLU(),
nn.Softmax(dim=1)
) criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.005) # PyTorch suit to tiny learning rate error = list() for epoch in range(250):
optimizer.zero_grad()
y_pred = net(x_train)
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()
error.append(loss.item()) y_pred = net(x_test)
y_pred = torch.argmax(y_pred, dim=1) # it is necessary that drawing the loss plot when we fine tuning the model
plt.plot(np.arange(1, len(error)+1), error)
plt.show() print(classification_report(y_test, y_pred, target_names=breast_cancer.target_names))

模型性能:

              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        14
versicolor 1.00 1.00 1.00 16
virginica 1.00 1.00 1.00 20 accuracy 1.00 50
macro avg 1.00 1.00 1.00 50
weighted avg 1.00 1.00 1.00 50

Iris Classification on PyTorch的更多相关文章

  1. Iris Classification on Tensorflow

    Iris Classification on Tensorflow Neural Network formula derivation \[ \begin{align} a & = x \cd ...

  2. Iris Classification on Keras

    Iris Classification on Keras Installation Python3 版本为 3.6.4 : : Anaconda conda install tensorflow==1 ...

  3. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  4. Pytorch collate_fn用法

    By default, Dataloader use collate_fn method to pack a series of images and target as tensors (first ...

  5. pytorch和tensorflow的爱恨情仇之定义可训练的参数

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch版本:1.6.0 tensorflow版本:1.15.0 之前我们就已 ...

  6. pytorch下对简单的数据进行分类(classification)

    看了Movan大佬的文字教程让我对pytorch的基本使用有了一定的了解,下面简单介绍一下二分类用pytorch的基本实现! 希望详细的注释能够对像我一样刚入门的新手来说有点帮助! import to ...

  7. pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》

    论文  < Convolutional Neural Networks for Sentence Classification>通过CNN实现了文本分类. 论文地址: 666666 模型图 ...

  8. pytorch之 classification

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  9. pytorch 5 classification 分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

随机推荐

  1. Spring+SpringMVC+MyBatis整合配置

    前端控制器 web.xml <?xml version="1.0" encoding="UTF-8"?> <web-app version=& ...

  2. C# 基于Aspose.Cells的数据导出到Excel

    using Aspose.Cells;  void WriteToExcel(string filePath, List<object[]> datas, string sheetName ...

  3. JS实例4

    根据当前年的前五年后五年的年月日 <select id="nian" onclick="Bian()"></select>年 <s ...

  4. ef entity转json引起的Self referencing loop

    问题简介:前段时间做项目时,将取到的entity往Redis cache里存放时报多重引用的错误. Self referencing loop detected for property 'Check ...

  5. 并发工具CyclicBarrier

    想想一下这样一个场景,有多个人需要过河,河上有一条船,船要等待满10个人才过河,过完河后每个人又各自行动. 这里的人相当于线程,注意这里,每个线程运行到一半的时候,它就要等待一个条件,即船满过河的条件 ...

  6. [IDE] ECLIPSE取消自动更新

    eclipse自动更新的取消方法: window --> preferences --> General --> Startup and Shutdown --> 在列表中找到 ...

  7. 清华操作系统实验--80x86汇编基础

    前言 80x86架构里,因为历史原因字是16位的,因此在汇编指令中用后缀-b,-w,-l来表示操作数是字节 字 或是双字 C声明 Intel数据类型 汇编代码后缀 大小(字节) char 字节 b 1 ...

  8. Python全栈-day6-day7-字符编码和文件处理

    一.字符编码 1.编码基础 定义:人在使用计算机时,使用的是人类能够读懂的字符,使用者必须通过一张字符和数字间的相对应关系表实现人机交互,这一系列标准称为字符编码 Python应用中解决核心字符串乱码 ...

  9. 4.构造Thread对象你也许不知道的几件事

    1.Thread类对象只有在调用了start()方法之后,JVM虚拟机才会给我们创建一个真正的线程!否则就不能说是创建了线程!也就是说new Thread()之后,此时实际上在计算机底层,操作系统实际 ...

  10. IFrame session(转)

    问题场景: 在一个应用(集团门户)的某个page中, 通过IFrame的方式嵌入另一个应用(集团实时监管系统)的某个页面. 当两个应用的domain 不一样时, 在被嵌入的页面中Session失效.( ...