38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
|
'''-------------一、SE模块-----------------------------'''
|
|||
|
import torch
|
|||
|
from torch import nn
|
|||
|
|
|||
|
|
|||
|
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
|
|||
|
class SE_Block(nn.Module):
|
|||
|
def __init__(self, inchannel, ratio=16):
|
|||
|
super(SE_Block, self).__init__()
|
|||
|
# 全局平均池化(Fsq操作)
|
|||
|
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
|||
|
# 两个全连接层(Fex操作)
|
|||
|
self.fc = nn.Sequential(
|
|||
|
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
|
|||
|
nn.ReLU(),
|
|||
|
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
|
|||
|
nn.Sigmoid()
|
|||
|
)
|
|||
|
|
|||
|
def forward(self, x):
|
|||
|
# 读取批数据图片数量及通道数
|
|||
|
b, c, h, w = x.size()
|
|||
|
# Fsq操作:经池化后输出b*c的矩阵
|
|||
|
y = self.gap(x).view(b, c)
|
|||
|
# Fex操作:经全连接层输出(b,c,1,1)矩阵
|
|||
|
y = self.fc(y).view(b, c, 1, 1)
|
|||
|
# Fscale操作:将得到的权重乘以原来的特征图x
|
|||
|
return x * y.expand_as(x)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
input = torch.randn(1, 64, 32, 32)
|
|||
|
seblock = SE_Block(64)
|
|||
|
print(seblock)
|
|||
|
output = seblock(input)
|
|||
|
print(input.shape)
|
|||
|
print(output.shape)
|
|||
|
|