feat(net): 替换 SMFA 为 SCSA 并调整相关配置

- 将 SMFA 模块替换为 SCSA 模块
- 更新项目配置,使用本地 Python 3.8 环境
-调整 SCSA 模块参数,如维度、头数等
- 优化注意力机制,提高模型性能
This commit is contained in:
zjut 2024-11-08 12:04:52 +08:00
parent 41b2ea1ff9
commit e1a339e04b
6 changed files with 166 additions and 6 deletions

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" jdkType="Python SDK" /> <orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
<component name="PyDocumentationSettings"> <component name="PyDocumentationSettings">

View File

@ -12,5 +12,5 @@
</MavenGeneralSettings> </MavenGeneralSettings>
</option> </option>
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pfcfuse)" project-jdk-type="Python SDK" />
</project> </project>

Binary file not shown.

156
componets/SCSA.py Normal file
View File

@ -0,0 +1,156 @@
import typing as t
import torch
import torch.nn as nn
from einops.einops import rearrange
from mmengine.model import BaseModule
__all__ = ['SCSA']
"""SCSA探索空间注意力和通道注意力之间的协同作用
通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显着的改进
虽然它们的结合更有利于发挥各自的优势但通道和空间注意力之间的协同作用尚未得到充分探索缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异
我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系提出了一种新颖的空间和通道协同注意力模块SCSA我们的SCSA由两部分组成可共享的多语义空间注意力SMSA和渐进式通道自注意力PCSA
SMSA 集成多语义信息并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中有效地指导通道重新校准此外PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异
我们在七个基准数据集上进行了广泛的实验包括 ImageNet-1K 上的分类MSCOCO 2017 上的对象检测ADE20K 上的分割以及其他四个复杂场景检测数据集我们的结果表明我们提出的 SCSA 不仅超越了当前最先进的注意力机制
而且在各种任务场景中表现出增强的泛化能力
"""
class SCSA(BaseModule):
def __init__(
self,
dim: int,
head_num: int,
window_size: int = 7,
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
qkv_bias: bool = False,
fuse_bn: bool = False,
norm_cfg: t.Dict = dict(type='BN'),
act_cfg: t.Dict = dict(type='ReLU'),
down_sample_mode: str = 'avg_pool',
attn_drop_ratio: float = 0.,
gate_layer: str = 'sigmoid',
):
super(SCSA, self).__init__()
self.dim = dim
self.head_num = head_num
self.head_dim = dim // head_num
self.scaler = self.head_dim ** -0.5
self.group_kernel_sizes = group_kernel_sizes
self.window_size = window_size
self.qkv_bias = qkv_bias
self.fuse_bn = fuse_bn
self.down_sample_mode = down_sample_mode
assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
self.group_chans = group_chans = self.dim // 4
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
padding=group_kernel_sizes[0] // 2, groups=group_chans)
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
padding=group_kernel_sizes[1] // 2, groups=group_chans)
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
padding=group_kernel_sizes[2] // 2, groups=group_chans)
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
padding=group_kernel_sizes[3] // 2, groups=group_chans)
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
self.norm_h = nn.GroupNorm(4, dim)
self.norm_w = nn.GroupNorm(4, dim)
self.conv_d = nn.Identity()
self.norm = nn.GroupNorm(1, dim)
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
if window_size == -1:
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
else:
if down_sample_mode == 'recombination':
self.down_func = self.space_to_chans
# dimensionality reduction
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
elif down_sample_mode == 'avg_pool':
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
elif down_sample_mode == 'max_pool':
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
The dim of x is (B, C, H, W)
"""
# Spatial attention priority calculation
b, c, h_, w_ = x.size()
# (B, C, H)
x_h = x.mean(dim=3)
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
# (B, C, W)
x_w = x.mean(dim=2)
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
self.local_dwc(l_x_h),
self.global_dwc_s(g_x_h_s),
self.global_dwc_m(g_x_h_m),
self.global_dwc_l(g_x_h_l),
), dim=1)))
x_h_attn = x_h_attn.view(b, c, h_, 1)
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
self.local_dwc(l_x_w),
self.global_dwc_s(g_x_w_s),
self.global_dwc_m(g_x_w_m),
self.global_dwc_l(g_x_w_l)
), dim=1)))
x_w_attn = x_w_attn.view(b, c, 1, w_)
x = x * x_h_attn * x_w_attn
# Channel attention based on self attention
# reduce calculations
y = self.down_func(x)
y = self.conv_d(y)
_, _, h_, w_ = y.size()
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
# (B, C, H, W) -> (B, head_num, head_dim, N)
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
# (B, head_num, head_dim, head_dim)
attn = q @ k.transpose(-2, -1) * self.scaler
attn = self.attn_drop(attn.softmax(dim=-1))
# (B, head_num, head_dim, N)
attn = attn @ v
# (B, C, H_, W_)
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
# (B, C, 1, 1)
attn = attn.mean((2, 3), keepdim=True)
attn = self.ca_gate(attn)
return attn * x
if __name__ == '__main__':
block = SCSA(
dim=256,
head_num=8,
)
input_tensor = torch.rand(1, 256, 32, 32)
# 调用模块进行前向传播
output_tensor = block(input_tensor)
# 打印输入和输出张量的大小
print("Input size:", input_tensor.size())
print("Output size:", output_tensor.size())

12
net.py
View File

@ -6,7 +6,8 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange from einops import rearrange
from componets.SMFA import SMFA from componets.SCSA import SCSA
from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
from componets.WTConvCV2 import WTConv2d from componets.WTConvCV2 import WTConv2d
@ -159,7 +160,8 @@ class BaseFeatureExtraction(nn.Module):
self.WTConv2d = WTConv2d(dim, dim) self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = SMFA(dim=dim) # self.token_mixer = SMFA(dim=dim)
self.token_mixer = SCSA(dim=dim,head_num=8)
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias') self.norm2 = LayerNorm(dim, 'WithBias')
@ -195,9 +197,11 @@ class BaseFeatureExtraction(nn.Module):
) )
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。 # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
) )
pol = self.poolmlp(self.norm2(x))
x = wtConvX + self.drop_path( x = wtConvX + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x))) * pol)
else: else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.poolmlp(self.norm2(x))) x = x + self.drop_path(self.poolmlp(self.norm2(x)))
@ -569,7 +573,7 @@ class Restormer_Decoder(nn.Module):
nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3, nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias), ) stride=1, padding=1, bias=bias), )
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim) self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(in_channels=dim)
def forward(self, inp_img, base_feature, detail_feature): def forward(self, inp_img, base_feature, detail_feature):