为了账号安全,请及时绑定邮箱和手机立即绑定

【2024年】第2天 神经网络中的激活函数2(pytorch)

标签:
Python

1. 激活函数的多种形式

  • 在pytorch中,每个激活函数都有两种形式,类形式和函数形式。
  • 类形式在torch.nn模块中定义。在使用时,需要先对其实例化才能应用。
  • 函数形式在torch.nn.functional模块中定义。在使用时,可以直接以函数调用的方式进行。
  • 以激活函数tanh为例,定义激活函数:
  1. 以类形式使用,在模型类的init()方法中,定义激活函数
# 对tanh类进行实例化
self.tanh = torch.nn.tanh()

接着便可以在模型类的forward()方法中,添加激活函数的应用

# 应用tanh类的实例化对象
output = self.tanh(input)
  1. 以类的形式直接使用:还可以将1中的操作统一在模型类中的forward()方法中完成。
output = torch.nn.tanh()(input)
  1. 以函数的形式使用
    在模型类的forward()方法中,直接调用激活函数的方式。
output = torch.nn.functional.tanh(input)

在以函数的形式使用激活函数时,该激活函数不会驻留在模型类的内存里,会与其他的pytorch库函数一样,在全局内存中被调用。
在torch.nn.functional中激活函数的命名都是小写形式。

2. Swish函数

  • 好的激活函数可以对特征数据的激活更加精准,能够提高模型的精度。
  • 目前,业界公认好的激活函数为Swish和Mish。
  • 在保持结构不变的基础上,直接将模型中的其他激活函数换成Swish或Mish激活函数,都会使模型的精度有所提升。
import numpy as np
import matplotlib.pyplot as plt

# 定义Swish函数
def swish(x, beta=1.0):
    return x * (1.0 / (1.0 + np.exp(-beta*x)))

# 定义Mish函数
def mish(x):
    return x * np.tanh(np.log(1 + np.exp(x)))

x = np.linspace(-5, 5, 100)
y_swish = swish(x)
y_mish = mish(x)

plt.plot(x, y_swish, label='Swish')
plt.plot(x, y_mish, label='Mish')
plt.title('Swish VS Mish Activation Functions')
plt.xlabel('x')
plt.ylabel('Activation Function Output')
plt.legend()
plt.grid(True)
plt.show()

运行结果:
图片描述

  • 二者的曲线非常的相似,如上所示Mish比Swish更胜一筹。
  • Swish是谷歌公司发现一个效果更优于ReLU的激活函数。
  • 在测试中,保持所有的模型参数参数不变,只是把原来模型中的ReLU激活函数修改为Swish激活函数,模型的准确率均由提升。公式如下所示:

f(x) = xSigmoid(βx)

  • 其中β为x的缩放参数,一般情况下默认值1即可。
  • 在使用了批量归一化算法的情况下,还需要对x的缩放值β进行调节。
  • 在实际应用中,β参数可以是常数,由手动调节,也可以是可训练的参数,由神经网络自己学习。

3. Mish激活函数

  • Mish激活函数从Swish中获得"灵感",也是用输入变量与其非线性变化后的激活函数相乘。
  • 其中,将非线性变化部分的缩放参数β用Softplus激活函数来代替,使其无须输入任何标量就可以更改网络参数,其公式为:

f(x) = x tanh(Softplus(x))

  • 将Softplus的公式代入,也可以把上式写为

f(x) = x tanh(In(1+ex))

  • 相比Swish,Mish激活函数没有了参数,使用起来更方便。
import torch
import torch.nn as nn
import torch.nn.functional as F

# Swish激活函数
def swish(x, beta=1):
    return x * torch.nn.Sigmoid()(x*beta)

# Mish激活函数
def mish(x):
    return x * (torch.tanh(F.softplus))

class Mish(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x * (torch.tanh(F.softplus))

4.激活函数的总结

  • 在神经网络中,运算特征不断进行循环计算,因此,在每次循环过程中,每个神经元的值也是在不断变化的,特征间的差距会再循环过程中被不断地方法,当输入数据本身差别较大时,用tanh会好一点。
  • 当输入数据本身差别不大时,用Sigmoid效果就会更好一些。
  • 而后来出现的ReLU激活函数,主要优势是能够生成稀疏性更好的特征数据,即将数据转化为只有最大数值,其他都为0的特征。
  • 这种变换可以更好地突出输入特征,用大多数元素为0的稀疏矩阵来实现。
  • Swish激活函数和Mish激活函数是ReLU基础上进一步优化产生的,在深层神经网络中效果更加明显。
  • Mish激活函数会比Swish激活函数还要好一些。

5. 训练模型的步骤与方法

  1. 将样本数据输入模型算出正向的结果。
  2. 计算模型结果与样本的目标标签之间的差距(也称为损失值,即loss)
  3. 根据损失值,使用链式反向求导,依次计算模型中每个参数(即权重)的梯度。
  4. 使用优化器中的策略对模型中的参数进行更新。
点击查看更多内容
1人点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消