pfcfuse/componets/SEBlock.py

38 lines
1.2 KiB
Python
Raw Normal View History

'''-------------一、SE模块-----------------------------'''
import torch
from torch import nn
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
class SE_Block(nn.Module):
def __init__(self, inchannel, ratio=16):
super(SE_Block, self).__init__()
# 全局平均池化(Fsq操作)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
# 两个全连接层(Fex操作)
self.fc = nn.Sequential(
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
nn.ReLU(),
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
nn.Sigmoid()
)
def forward(self, x):
# 读取批数据图片数量及通道数
b, c, h, w = x.size()
# Fsq操作经池化后输出b*c的矩阵
y = self.gap(x).view(b, c)
# Fex操作经全连接层输出bc11矩阵
y = self.fc(y).view(b, c, 1, 1)
# Fscale操作将得到的权重乘以原来的特征图x
return x * y.expand_as(x)
if __name__ == '__main__':
input = torch.randn(1, 64, 32, 32)
seblock = SE_Block(64)
print(seblock)
output = seblock(input)
print(input.shape)
print(output.shape)