生成对抗网络(Generative Adversarial Network, GAN)是一种通过对抗训练生成数据的深度学习模型,由生成器(Generator)和判别器(Discriminator)两部分组成,其核心思想源于博弈论中的零和博弈。

一、核心组成

生成器(G)

  目标:生成逼真的假数据(如图像、文本),试图欺骗判别器。

  输入:随机噪声(通常服从高斯分布或均匀分布)。

  输出:合成数据(如假图像)。

判别器(D)

  目标:区分真实数据(来自训练集)和生成器合成的假数据。

  输出:概率值(0到1),表示输入数据是真实的概率。

二、关于对抗训练

1. 动态博弈

  1)生成器尝试生成越来越逼真的数据,使得判别器无法区分真假。

  2)判别器则不断优化自身,以更准确地区分真假数据。

  3)两者交替训练,最终达到纳什均衡(生成器生成的数据与真实数据分布一致,判别器无法区分,输出概率恒为0.5)。

2. 优化目标(极小极大博弈)

\[\min_{G}{\max_D}V(D,G)=E_{x\sim p_{data}}[logD(x)]+E_{z\sim p_z}[log(1-D(G(z)))]
\]

  其中,

    \(D(x)\):判别器对真实数据的判别结果;

    \(G(z)\):生成器生成的假数据;

    判别器希望最大化\(V(D,G)\)(正确分类真假数据);

    生成器希望最小化\(V(D,G)\)(让判别器无法区分)。

3.交替更新

1) 固定生成器,训练判别器:

  用真实数据(标签1)和生成数据(标签0)训练判别器,提高其鉴别能力。

2) 固定判别器,训练生成器:

  通过反向传播调整生成器参数,使得判别器对生成数据的输出概率接近1(即欺骗判别器)。

三、典型应用

  图像生成:生成逼真的人脸、风景、艺术画(如 DCGAN、StyleGAN);

  图像编辑:图像修复(填补缺失区域)、风格迁移(如将照片转为油画风格);

  数据增强:为小样本任务生成额外的训练数据;

  超分辨率重建:将低分辨率图像恢复为高分辨率图像。

四、优势与挑战

优势

  无监督学习:无需对数据进行标注,仅通过真实数据即可训练(适用于标注成本高的场景)。

  生成高质量数据:相比其他生成模型(如变分自编码器 VAE),GAN 在图像生成等任务中往往能生成更逼真、细节更丰富的数据。

  灵活性:生成器和判别器可以采用不同的网络结构(如卷积神经网络 CNN、循环神经网络 RNN 等),适用于多种数据类型(图像、文本、音频等)。

挑战

  训练不稳定:容易出现 “模式崩溃”(生成器只生成少数几种相似数据,缺乏多样性)或难以收敛;

  平衡难题:生成器和判别器的能力需要匹配,否则可能一方过强导致另一方无法学习(如判别器太弱,生成器无需优化即可欺骗它);

  可解释性差:生成器的内部工作机制难以解释,生成结果的可控性较弱(近年通过改进模型如 StyleGAN 缓解了这一问题)。

五、Python示例

  使用 PyTorch 实现简单 的GAN 模型,生成手写数字图像。

import matplotlib
matplotlib.use('TkAgg') import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np plt.rcParams['font.sans-serif']=['SimHei'] # 中文支持
plt.rcParams['axes.unicode_minus']=False # 负号显示 # 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42) # 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 数据加载和预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # 将图像归一化到 [-1, 1]
]) train_dataset = datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 定义生成器网络
class Generator(nn.Module):
def __init__(self, latent_dim=100, img_dim=784):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, img_dim),
nn.Tanh() # 输出范围 [-1, 1]
) def forward(self, z):
return self.model(z).view(z.size(0), 1, 28, 28) # 定义判别器网络
class Discriminator(nn.Module):
def __init__(self, img_dim=784):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_dim, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid() # 输出概率值
) def forward(self, img):
img_flat = img.view(img.size(0), -1)
return self.model(img_flat) # 初始化模型
latent_dim = 100
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device) # 定义损失函数和优化器
criterion = nn.BCELoss()
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999)) # 训练函数
def train_gan(epochs):
for epoch in range(epochs):
for i, (real_imgs, _) in enumerate(train_loader):
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(device) # 创建标签
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device) # ---------------------
# 训练判别器
# ---------------------
d_optimizer.zero_grad() # 计算判别器对真实图像的损失
real_pred = discriminator(real_imgs)
d_real_loss = criterion(real_pred, real_labels) # 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z) # 计算判别器对假图像的损失
fake_pred = discriminator(fake_imgs.detach())
d_fake_loss = criterion(fake_pred, fake_labels) # 总判别器损失
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
d_optimizer.step() # ---------------------
# 训练生成器
# ---------------------
g_optimizer.zero_grad() # 生成假图像
fake_imgs = generator(z)
# 计算判别器对假图像的预测
fake_pred = discriminator(fake_imgs)
# 生成器希望判别器将假图像判断为真
g_loss = criterion(fake_pred, real_labels)
g_loss.backward()
g_optimizer.step() # 打印训练进度
if i % 100 == 0:
print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(train_loader)} "
f"Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}") # 每个epoch结束后,生成一些样本图像
if (epoch + 1) % 10 == 0:
generate_samples(generator, epoch + 1, latent_dim, device) # 生成样本图像
def generate_samples(generator, epoch, latent_dim, device, n_samples=16):
generator.eval()
z = torch.randn(n_samples, latent_dim).to(device)
with torch.no_grad():
samples = generator(z).cpu() # 可视化生成的样本
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flatten()):
ax.imshow(samples[i][0].numpy(), cmap='gray')
ax.axis('off') plt.tight_layout()
plt.savefig(f"gan_samples/gan_samples_epoch_{epoch}.png")
plt.close()
generator.train() # 训练模型
train_gan(epochs=50) # 生成最终样本
generate_samples(generator, "final", latent_dim, device)

最终生成的样本:

六、小结

  GAN通过对抗机制实现了强大的生成能力,成为生成模型领域的里程碑技术。衍生变体(如CGAN、CycleGAN等)进一步扩展了其应用场景。

End.

生成式对抗网络(GAN)模型原理概述的更多相关文章

  1. 生成式对抗网络GAN 的研究进展与展望

    生成式对抗网络GAN的研究进展与展望.pdf 摘要: 生成式对抗网络GAN (Generative adversarial networks) 目前已经成为人工智能学界一个热门的研究方向. GAN的基 ...

  2. 0901-生成对抗网络GAN的原理简介

    0901-生成对抗网络GAN的原理简介 目录 一.GAN 概述 二.GAN 的网络结构 三.通过一个举例具体化 GAN 四.GAN 的设计细节 pytorch完整教程目录:https://www.cn ...

  3. 【CV论文阅读】生成式对抗网络GAN

    生成式对抗网络GAN 1.  基本GAN 在论文<Generative Adversarial Nets>提出的GAN是最原始的框架,可以看成极大极小博弈的过程,因此称为“对抗网络”.一般 ...

  4. 生成式对抗网络(GAN)实战——书法字体生成练习赛

    https://www.tinymind.cn/competitions/ai 生成式对抗网络(GAN)是近年来大热的深度学习模型. 目前GAN最常使用的场景就是图像生成,作为一种优秀的生成式模型,G ...

  5. 【神经网络与深度学习】生成式对抗网络GAN研究进展(五)——Deep Convolutional Generative Adversarial Nerworks,DCGAN

    [前言]      本文首先介绍生成式模型,然后着重梳理生成式模型(Generative Models)中生成对抗网络(Generative Adversarial Network)的研究与发展.作者 ...

  6. 【机器学习】李宏毅——生成式对抗网络GAN

    1.基本概念介绍 1.1.What is Generator 在之前我们的网络架构中,都是对于输入x得到输出y,只要输入x是一样的,那么得到的输出y就是一样的. 但是Generator不一样,它最大的 ...

  7. Keras入门——(3)生成式对抗网络GAN

    导入 matplotlib 模块: import matplotlib 查看自己版本所支持的backends: print(matplotlib.rcsetup.all_backends) 返回信息: ...

  8. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  9. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  10. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

随机推荐

  1. FREERTOS指令测试的思考

    采用freertos编程后,有必要增加指令测试的功能:   1.便于对于软件各个子模块或功能点进行测试.   2.便于对硬件各个功能点是否正常进行测试. 在裸机编程中,我们一般使用串口进行运行数据的打 ...

  2. spring-boot静态资源目录配置

    spring-boot静态资源目录配置(配置js.css.图片等资源的位置) spring-boot静态资源默认为/src/main/resources下的/static目录,可以通过applicat ...

  3. Rust实战系列-生命周期、所有权和借用

    本文是<Rust in action>学习总结系列的第四部分,更多内容请看已发布文章: 一.Rust实战系列-Rust介绍 二.Rust实战系列-基本语法 三.Rust实战系列-复合数据类 ...

  4. 解决C盘根目录不能创建文件,只能创建文件夹问题

    转载:https://blog.csdn.net/xinke453/article/details/7496545 解决方法 用管理员运行cmd 输入 icacls c:\ /setintegrity ...

  5. ASP.NET Core Web API中操作方法中的参数来源

    在ASP.NET Core Web API中,有多种方式可以传递参数给操作方法.以下是一些常见的参数传递方式: 路由参数(Route Parameters):参数值从URL的路由中提取. // Rou ...

  6. Ubuntu Vmware虚拟机 没有共享文件夹/mnt/hgfs 解决

    问题现象 在Vmware虚拟机设置共享文件夹后,在Ubuntu系统侧应该在/mnt/hgfs目录下可见.然而,有时在重启虚拟机后不存在该文件夹. 解决方法 在终端中执行以下代码,需要root权限. s ...

  7. 微软开源bitnet b1.58大模型,应用效果测评(问答、知识、数学、逻辑、分析)

    微软开源bitnet b1.58大模型,应用效果测评(问答.知识.数学.逻辑.分析) 目       录 1.     前言... 2 2.     应用部署... 2 3.     应用效果... ...

  8. 保姆式Win11安装教程|Rufus工具制作U盘+绕过限制+驱动安装全解析(附资源包)

    Windows 11 简介 Windows 11是微软推出的全新一代操作系统,以直观交互和AI技术为核心升级.其界面采用圆角设计和居中任务栏布局,支持多窗口贴靠分屏与虚拟桌面功能,提升多任务处理效率. ...

  9. git基础及gitee配置

    安装git 网址:https://git-scm.com/book/zh/v2/起步-安装-Git 使用git 基本指令 # 初始化指令 git init # 管理目录下的文件状态 注:新增文件和修改 ...

  10. 【译】Visual Studio 扩展管理器更新

    Visual Studio 2022 的最新更新引入了专门的设计用于改进扩展管理方式的功能.这些更新提供的工具可以帮助您自动化过程,为配置提供详细的控制,并增强用户界面以简化您的开发工作流程. 无缝自 ...