Pytorch-实战之对Himmelblau函数的优化
1.Himmelblau函数
Himmelblau函数:
F(x,y)=(x²+y-11)²+(x+y²-7)²:具体优化的是,寻找一个最合适的坐标(x,y)使得F(x,y)的值最小。
函数的具体图像,如下图所示:

实现代码
import numpy as np
from matplotlib import pyplot as plt
import torch
# 定义函数
def himmelblau(x_y):
return (x_y[0] ** 2 + x_y[1] - 11) ** 2 + (x_y[0] + x_y[1] ** 2 - 7) ** 2
# 生成x轴数据列表
x = np.arange(-6, 6, 0.1)
# 生成y轴数据列表
y = np.arange(-6, 6, 0.1)
print('x,y range:', x.shape, y.shape)
# 对x,y数据进行网格化,
X, Y = np.meshgrid(x, y)
print('X,Y maps:', X.shape, Y.shape)
# 计算Z轴数据
Z = himmelblau([X, Y])
fig = plt.figure('himmelblau')
ax = fig.gca(projection='3d')
# 绘制3D图形
ax.plot_surface(X, Y, Z)
ax.view_init(60, -30)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
if __name__ == '__main__':
# [1., 0.], [-4, 0.], [4, 0.]
# x_y存储的是坐标值(x,y),目的就是求解一个最优的x_y。
x_y = torch.tensor([0., 0.], requires_grad=True)
# 定义优化器,优化器的目标就是x_y,学习速率learningrate是0.001
optimizer = torch.optim.Adam([x_y], lr=1e-3)
for step in range(20000):
# 输入坐标,得到预测值
pred = himmelblau(x_y)
# 当网络参量进行反馈时,梯度是被积累的而不是被替换掉,所以把梯度信息清零
optimizer.zero_grad()
# 获取x坐标和y坐标的梯度信息
pred.backward()
# 调用一次.step(),就会优化一次x坐标 x'=x-learningrate*▽x
# 调用一次.step(),就会优化一次y坐标 y'=y-learningrate*▽y
optimizer.step()
if step % 2000 == 0:
print ('step {}: x_y = {}, f(x) = {}'
.format(step, x_y.tolist(), pred.item()))
输出结果
x,y range: (120,) (120,)
X,Y maps: (120, 120) (120, 120)
step 0: x_y = [0.0009999999310821295, 0.0009999999310821295], f(x) = 170.0
step 2000: x_y = [2.3331806659698486, 1.9540694952011108], f(x) = 13.730916023254395
step 4000: x_y = [2.9820079803466797, 2.0270984172821045], f(x) = 0.014858869835734367
step 6000: x_y = [2.999983549118042, 2.0000221729278564], f(x) = 1.1074007488787174e-08
step 8000: x_y = [2.9999938011169434, 2.0000083446502686], f(x) = 1.5572823031106964e-09
step 10000: x_y = [2.999997854232788, 2.000002861022949], f(x) = 1.8189894035458565e-10
step 12000: x_y = [2.9999992847442627, 2.0000009536743164], f(x) = 1.6370904631912708e-11
step 14000: x_y = [2.999999761581421, 2.000000238418579], f(x) = 1.8189894035458565e-12
step 16000: x_y = [3.0, 2.0], f(x) = 0.0
step 18000: x_y = [3.0, 2.0], f(x) = 0.0
Pytorch-实战之对Himmelblau函数的优化的更多相关文章
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- SQL Server 聚合函数算法优化技巧
Sql server聚合函数在实际工作中应对各种需求使用的还是很广泛的,对于聚合函数的优化自然也就成为了一个重点,一个程序优化的好不好直接决定了这个程序的声明周期.Sql server聚合函数对一组值 ...
- 利用函数索引优化<>
SQL> select count(*),ID from test_2 group by id; COUNT(*) ID ---------- ---------- 131072 1 11796 ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- Pytorch中randn和rand函数的用法
Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- pytorch中的学习率调整函数
参考:https://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate torch.optim.lr_scheduler提供 ...
- PyTorch实战:经典模型LeNet5实现手写体识别
在上一篇博客CNN核心概念理解中,我们以LeNet为例介绍了CNN的重要概念.在这篇博客中,我们将利用著名深度学习框架PyTorch实现LeNet5,并且利用它实现手写体字母的识别.训练数据采用经典的 ...
随机推荐
- powershell 输入命令 不执行 保留输入内容 Ctrl + C
为什么 powershell 输入命令 不执行 保留输入内容 Ctrl + C 为了解释某些命令,但是不执行 比如 我说 dc命令就是 xxxxxxx 我就先输入 xxxxxxxx然后ctrl + c ...
- epoll实现的简单服务器
#include "../wrap/wrap.h" #include <sys/epoll.h> #define SIZE 1024 #define FUCK prin ...
- 基于QGIS生产建筑物高度与遥感影像数据集
1. 概述 利用遥感影像推知建筑物高度是一经典研究,现有很多学者利用机器学习的方式,利用现有数据进行训练从而构建模型 本文旨在记述使用QGIS进行建筑物高度与遥感影像数据集的获取与制作 如果不想自己动 ...
- 《Spring6核心源码解析》已完结,涵盖IOC容器、AOP切面、AOT预编译、SpringMVC,面试杠杠的!
作者:冰河 博客:https://binghe.gitcode.host 文章汇总:https://binghe.gitcode.host/md/all/all.html 源码地址:https://g ...
- Oracle NLSSORT 拼音排序 笔画排序 部首排序
create table test(name varchar2(20)); insert into test values('中国'); insert into test values('美国'); ...
- KingbaseESV8R6中查看索引常用sql
前言 KingbaseES具有丰富的索引功能,对于运行一段时间的数据库,经常需要查看索引的使用大小,使用状态等. 尤其重复索引的存在,有时会因为索引过多而造成维护成本加大和减慢数据库的运行速度. 下面 ...
- KingbaseES V8R6 几种不同的表复制方式
前言 当数据库遇到未知问题,有时候无法入手分析,在非生产数据库或者征得客户同意获得特殊时间,需要重建表解决,下面提供了多种不同的复制表的方法,我们了解一下他们的差异. 测试 1.CREATE TABL ...
- KingbaseES 函数与存储过程内容加密
说明: 数据库系统使用过程中,有些业务功能在特殊的安全级别情况下,需要对数据库中的函数和存储过程进行加密存储,以保证数据库函数和过程的代码安全性.KingbaseES 数据库,提供了DBMS_DDL扩 ...
- 2024-03-30:用go语言,集团里有 n 名员工,他们可以完成各种各样的工作创造利润, 第 i 种工作会产生 profit[i] 的利润,它要求 group[i] 名成员共同参与, 如果成员参与
2024-03-30:用go语言,集团里有 n 名员工,他们可以完成各种各样的工作创造利润, 第 i 种工作会产生 profit[i] 的利润,它要求 group[i] 名成员共同参与, 如果成员参与 ...
- 初学STM32 CAN通信(三)
1. stm32 CAN通信标准库函数 //CAN通信初始化函数 uint8_t CAN_Init(CAN_TypeDef* CANx, CAN_InitTypeDef* CAN_InitStruct ...