pfcfuse/componets/SEBlock.py
whai 30bbfdf86e feat(components): 添加 DEConv 和 SEBlock 组件
- 新增 DEConv 组件,用于细节增强卷积
- 新增 SEBlock组件,用于通道注意力机制
- 更新 net.py 中的 DetailNode 结构
- 调整 train.py 中的模型初始化
2024-11-14 16:59:11 +08:00

38 lines
1.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

'''-------------一、SE模块-----------------------------'''
import torch
from torch import nn
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
class SE_Block(nn.Module):
def __init__(self, inchannel, ratio=16):
super(SE_Block, self).__init__()
# 全局平均池化(Fsq操作)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
# 两个全连接层(Fex操作)
self.fc = nn.Sequential(
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
nn.ReLU(),
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
nn.Sigmoid()
)
def forward(self, x):
# 读取批数据图片数量及通道数
b, c, h, w = x.size()
# Fsq操作经池化后输出b*c的矩阵
y = self.gap(x).view(b, c)
# Fex操作经全连接层输出bc11矩阵
y = self.fc(y).view(b, c, 1, 1)
# Fscale操作将得到的权重乘以原来的特征图x
return x * y.expand_as(x)
if __name__ == '__main__':
input = torch.randn(1, 64, 32, 32)
seblock = SE_Block(64)
print(seblock)
output = seblock(input)
print(input.shape)
print(output.shape)