927d049f13
- 新增 DEConv、DynamicFilter、SCSA、SEBlock、SMFA、TIAM 和 UFFC 模块 - 这些模块提供了不同的特征增强功能,如卷积差分、频域滤波、注意力机制等 - 可以根据需求选择合适的模块来提升模型性能
38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
'''-------------一、SE模块-----------------------------'''
|
||
import torch
|
||
from torch import nn
|
||
|
||
|
||
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
|
||
class SE_Block(nn.Module):
|
||
def __init__(self, inchannel, ratio=16):
|
||
super(SE_Block, self).__init__()
|
||
# 全局平均池化(Fsq操作)
|
||
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
||
# 两个全连接层(Fex操作)
|
||
self.fc = nn.Sequential(
|
||
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
|
||
nn.ReLU(),
|
||
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
|
||
nn.Sigmoid()
|
||
)
|
||
|
||
def forward(self, x):
|
||
# 读取批数据图片数量及通道数
|
||
b, c, h, w = x.size()
|
||
# Fsq操作:经池化后输出b*c的矩阵
|
||
y = self.gap(x).view(b, c)
|
||
# Fex操作:经全连接层输出(b,c,1,1)矩阵
|
||
y = self.fc(y).view(b, c, 1, 1)
|
||
# Fscale操作:将得到的权重乘以原来的特征图x
|
||
return x * y.expand_as(x)
|
||
|
||
if __name__ == '__main__':
|
||
input = torch.randn(1, 64, 32, 32)
|
||
seblock = SE_Block(64)
|
||
print(seblock)
|
||
output = seblock(input)
|
||
print(input.shape)
|
||
print(output.shape)
|
||
|