import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
import torchvision
from mpl_toolkits.mplot3d import Axes3D #画3D图
from matplotlib import cm
# Hyper Parameters
EPOCH=10
BATCH_SIZE=64
LR = 0.005 # learning rate
DOWNLOAD_MNIST=False
N_TEST_IMG=5 train_data=torchvision.datasets.MNIST(
root='./mnist/',
train=True,
transform=torchvision.transforms.ToTensor(),
download=DOWNLOAD_MNIST
) train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True) class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__() self.encoder = nn.Sequential(
nn.Linear(28 * 28, 128),
nn.Tanh(),
nn.Linear(128,64),
nn.Tanh(),
nn.Linear(64, 12),
# nn.Tanh(),
# nn.Linear(12, 3),
)
self.decoder=nn.Sequential(
# nn.Linear(3,12),
# nn.Tanh(),
nn.Linear(12, 64),
nn.Tanh(),
nn.Linear(64, 128),
nn.Tanh(),
nn.Linear(128, 28*28),
nn.Sigmoid() ) def forward(self, x ):
encoder=self.encoder(x)
decoder=self.decoder(encoder)
return encoder,decoder AutoEncoder = AutoEncoder()
# print(AutoEncoder) optimizer = torch.optim.Adam(AutoEncoder.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.MSELoss() f,a=plt.subplots(2,N_TEST_IMG,figsize=(5,2)) plt.ion() # continuously plot view_data=train_data.train_data[:N_TEST_IMG].view(-1,28*28).type(torch.FloatTensor)/255 for i in range(N_TEST_IMG):
a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray')
a[0][i].set_xticks(())
a[0][i].set_yticks(()) for epoch in range(EPOCH):
for step,(x,b_label) in enumerate(train_loader):
b_x=x.view(-1,28*28)
b_y=x.view(-1,28*28)
encoded, decoded = AutoEncoder(b_x)
loss=loss_func(decoded,b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step%100==0:
print('Epoch:|',epoch,'train loss:%0.4f'%loss.data.numpy())
_,decoded_data=AutoEncoder(view_data)
for i in range(N_TEST_IMG):
a[1][i].clear()
a[1][i].imshow(np.reshape(decoded.data.numpy()[i],(28,28)),cmap='gray')
a[1][i].set_xticks(())
a[1][i].set_yticks(())
plt.draw()
plt.pause(0.05)
plt.ioff()
plt.show() view_data=train_data.train_data[:200].view(-1,28*28).type(torch.FloatTensor)/255
encoded_data,_=AutoEncoder(view_data)
fig=plt.figure(2)
ax=Axes3D(fig)
X,Y,Z=encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values=train_data.train_labels[:200].numpy()
for x,y,z ,s in zip(X,Y,Z,values):
c=cm.rainbow(int(255*s/9))
ax.text(x,y,z,s,backgroundcolor=c)
ax.set_xlim(X.min(),X.max())
ax.set_ylim(Y.min(),Y.max())
ax.set_zlim(Z.min(),Z.max())
plt.show()

选出五张图片做测试。

图像分为5*2显示,上面一行是原始图像,下面一行为编码和解码后的图像。

encode与decode的更多相关文章

  1. [LeetCode] Encode and Decode Strings 加码解码字符串

    Design an algorithm to encode a list of strings to a string. The encoded string is then sent over th ...

  2. 【python】python新手必碰到的问题---encode与decode,中文乱码[转]

    转自:http://blog.csdn.net/a921800467b/article/details/8579510 为什么会报错“UnicodeEncodeError:'ascii' codec ...

  3. LeetCode Encode and Decode Strings

    原题链接在这里:https://leetcode.com/problems/encode-and-decode-strings/ 题目: Design an algorithm to encode a ...

  4. Encode and Decode Strings

    Design an algorithm to encode a list of strings to a string. The encoded string is then sent over th ...

  5. encode和decode

    Python字符串的encode与decode研究心得乱码问题解决方法 为什么会报错“UnicodeEncodeError: 'ascii' codec can't encode characters ...

  6. 【python】浅谈encode和decode

    对于encode和decode,笔者也是根据自己的理解,有不对的地方还请多多指点. 编码的理解: 1.编码:utf-8,utf-16,gbk,gb2312,gb18030等,编码为了便于理解,可以把它 ...

  7. 271. Encode and Decode Strings

    题目: Design an algorithm to encode a list of strings to a string. The encoded string is then sent ove ...

  8. [LeetCode#271] Encode and Decode Strings

    Problem: Design an algorithm to encode a list of strings to a string. The encoded string is then sen ...

  9. Encode and Decode Strings 解答

    Question Design an algorithm to encode a list of strings to a string. The encoded string is then sen ...

  10. python encode和decode函数说明【转载】

    python encode和decode函数说明 字符串编码常用类型:utf-8,gb2312,cp936,gbk等. python中,我们使用decode()和encode()来进行解码和编码 在p ...

随机推荐

  1. subgradients

    目录 定义 上镜图解释 次梯度的存在性 性质 极值 非负数乘 \(\alpha f(x)\) 和,积分,期望 仿射变换 仿梯度 混合函数 应用 Pointwise maximum 上确界 suprem ...

  2. 判断一个点是否在某个区域内。百度,高德,腾讯都能用。(php版)

    <?php // *** 配置文件(表示区域的三维数组)其内的点,必须按顺时针方向依次给出! $area = array( // 天通苑店 0 => array( array('x'=&g ...

  3. Flutter获取屏幕宽高和Widget大小

    我们平时在开发中的过程中通常都会获取屏幕或者 widget 的宽高用来做一些事情,在 Flutter 中,我们可以使用如下方法来获取屏幕或者 widget 的宽高. MediaQuery 一般情况下, ...

  4. react-navigation使用之嵌套和跳转

    1. 新版react-native已经将react-navigation作为官方版本发布,基础Demo可以从官方网站获得,比较困扰的问题是组件的嵌套和第二.第三页面的跳转. 2. 组件嵌套问题: 要在 ...

  5. bzoj 3282: Tree (Link Cut Tree)

    链接:https://www.lydsy.com/JudgeOnline/problem.php?id=3282 题面: 3282: Tree Time Limit: 30 Sec  Memory L ...

  6. 钉钉相关功能介入开发系列一:获取access_token

    获取access_token的基本代码,与微信不同的是钉钉的token正常情况下有效期为7200秒,有效期内重复获取返回相同结果,并自动续期,比微信方便多了 //基本信息 string appkey ...

  7. Codeforces 1093D Beautiful Graph(二分图染色+计数)

    题目链接:Beautiful Graph 题意:给定一张无向无权图,每个顶点可以赋值1,2,3,现要求相邻节点一奇一偶,求符合要求的图的个数. 题解:由于一奇一偶,需二分图判定,染色.判定失败,直接输 ...

  8. 浅析redis缓存 在spring中的配置 及其简单的使用

    一:如果你需要在你的本地项目中配置redis.那么你首先得需要在你的本地安装redis 参考链接[http://www.runoob.com/redis/redis-install.html] 下载r ...

  9. Tomcat 配置文件 server.xml

    Tomcat隶属于Apache基金会,是开源的轻量级Web应用服务器,使用非常广泛.server.xml是Tomcat中最重要的配置文件,server.xml的每一个元素都对应了Tomcat中的一个组 ...

  10. App测试全(转自鲁德)

    1.App测试流程 1.1流程图 1.2测试周期 测试周期可按项目的开发周期来确定测试时间,一般测试时间为两三周(即15个工作日),根据项目情况以及版本质量可适当缩短或延长测试时间. 1.3测试资源 ...