'''-------------一、SE模块-----------------------------''' import torch from torch import nn # 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid class SE_Block(nn.Module): def __init__(self, inchannel, ratio=16): super(SE_Block, self).__init__() # 全局平均池化(Fsq操作) self.gap = nn.AdaptiveAvgPool2d((1, 1)) # 两个全连接层(Fex操作) self.fc = nn.Sequential( nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r nn.ReLU(), nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c nn.Sigmoid() ) def forward(self, x): # 读取批数据图片数量及通道数 b, c, h, w = x.size() # Fsq操作:经池化后输出b*c的矩阵 y = self.gap(x).view(b, c) # Fex操作:经全连接层输出(b,c,1,1)矩阵 y = self.fc(y).view(b, c, 1, 1) # Fscale操作:将得到的权重乘以原来的特征图x return x * y.expand_as(x) if __name__ == '__main__': input = torch.randn(1, 64, 32, 32) seblock = SE_Block(64) print(seblock) output = seblock(input) print(input.shape) print(output.shape)