feat(net): replace SMFA with SCSA and adjust related configuration
- Replace the SMFA module with the SCSA module
- Update the project configuration to use the local Python 3.8 environment
- Adjust SCSA module parameters such as dimension and number of heads
- Optimize the attention mechanism to improve model performance
parent 41b2ea1ff9
commit e1a339e04b
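For quick orientation, a minimal sketch of the swap, inferred only from the net.py hunk further down (the channel width 64 here is a placeholder, not a value taken from this commit): SCSA is used as a drop-in token mixer where SMFA used to be, and it preserves the (B, C, H, W) shape of its input.

    import torch
    from componets.SCSA import SCSA

    dim = 64                                  # hypothetical channel width; net.py passes its own `dim`
    token_mixer = SCSA(dim=dim, head_num=8)   # replaces the previous SMFA(dim=dim)

    x = torch.rand(1, dim, 32, 32)
    y = token_mixer(x)                        # output keeps the input shape (B, C, H, W)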
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" jdkType="Python SDK" />
+    <orderEntry type="inheritedJdk" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="PyDocumentationSettings">
@@ -12,5 +12,5 @@
       </MavenGeneralSettings>
     </option>
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pfcfuse)" project-jdk-type="Python SDK" />
 </project>
BIN  componets/SCSA(CV2维图像).pdf  Normal file
Binary file not shown.

156  componets/SCSA.py  Normal file
@@ -0,0 +1,156 @@
import typing as t

import torch
import torch.nn as nn
from einops.einops import rearrange
from mmengine.model import BaseModule

__all__ = ['SCSA']


"""SCSA: Exploring the synergy between spatial attention and channel attention

Channel attention and spatial attention bring significant improvements in extracting feature dependencies
and spatial structure relations, respectively, for various downstream vision tasks. While combining them
better exploits their individual strengths, the synergy between channel and spatial attention has not been
fully explored, and the synergistic potential of multi-semantic information for feature guidance and for
mitigating semantic discrepancies remains under-exploited.

This work seeks to reveal the synergistic relationship between spatial and channel attention at multiple
semantic levels and proposes a novel Spatial and Channel Synergistic Attention module (SCSA). SCSA consists
of two parts: shareable Multi-Semantic Spatial Attention (SMSA) and Progressive Channel Self-Attention (PCSA).

SMSA integrates multi-semantic information and uses a progressive compression strategy to inject
discriminative spatial priors into PCSA's channel self-attention, effectively guiding channel recalibration.
In addition, the robust feature interaction based on the self-attention mechanism in PCSA further mitigates
the discrepancies in multi-semantic information across the different sub-features of SMSA.

Extensive experiments on seven benchmark datasets, including classification on ImageNet-1K, object detection
on MSCOCO 2017, segmentation on ADE20K, and four other complex-scene detection datasets, show that the
proposed SCSA not only surpasses current state-of-the-art attention mechanisms but also exhibits enhanced
generalization across various task scenarios.
"""


class SCSA(BaseModule):

    def __init__(
            self,
            dim: int,
            head_num: int,
            window_size: int = 7,
            group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
            qkv_bias: bool = False,
            fuse_bn: bool = False,
            norm_cfg: t.Dict = dict(type='BN'),
            act_cfg: t.Dict = dict(type='ReLU'),
            down_sample_mode: str = 'avg_pool',
            attn_drop_ratio: float = 0.,
            gate_layer: str = 'sigmoid',
    ):
        super(SCSA, self).__init__()
        self.dim = dim
        self.head_num = head_num
        self.head_dim = dim // head_num
        self.scaler = self.head_dim ** -0.5
        self.group_kernel_sizes = group_kernel_sizes
        self.window_size = window_size
        self.qkv_bias = qkv_bias
        self.fuse_bn = fuse_bn
        self.down_sample_mode = down_sample_mode

        assert self.dim % 4 == 0, 'The dimension of input feature should be divisible by 4.'
        self.group_chans = group_chans = self.dim // 4

        # Multi-semantic spatial attention: one depth-wise 1-D conv per channel group,
        # with increasing kernel sizes for local-to-global receptive fields.
        self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
                                   padding=group_kernel_sizes[0] // 2, groups=group_chans)
        self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
                                      padding=group_kernel_sizes[1] // 2, groups=group_chans)
        self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
                                      padding=group_kernel_sizes[2] // 2, groups=group_chans)
        self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
                                      padding=group_kernel_sizes[3] // 2, groups=group_chans)
        self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
        self.norm_h = nn.GroupNorm(4, dim)
        self.norm_w = nn.GroupNorm(4, dim)

        self.conv_d = nn.Identity()
        self.norm = nn.GroupNorm(1, dim)
        self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()

        # Spatial down-sampling applied before the channel self-attention to reduce computation.
        if window_size == -1:
            self.down_func = nn.AdaptiveAvgPool2d((1, 1))
        else:
            if down_sample_mode == 'recombination':
                self.down_func = self.space_to_chans
                # dimensionality reduction
                self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
            elif down_sample_mode == 'avg_pool':
                self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
            elif down_sample_mode == 'max_pool':
                self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The dim of x is (B, C, H, W)
        """
        # Spatial attention priority calculation
        b, c, h_, w_ = x.size()
        # (B, C, H)
        x_h = x.mean(dim=3)
        l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
        # (B, C, W)
        x_w = x.mean(dim=2)
        l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)

        x_h_attn = self.sa_gate(self.norm_h(torch.cat((
            self.local_dwc(l_x_h),
            self.global_dwc_s(g_x_h_s),
            self.global_dwc_m(g_x_h_m),
            self.global_dwc_l(g_x_h_l),
        ), dim=1)))
        x_h_attn = x_h_attn.view(b, c, h_, 1)

        x_w_attn = self.sa_gate(self.norm_w(torch.cat((
            self.local_dwc(l_x_w),
            self.global_dwc_s(g_x_w_s),
            self.global_dwc_m(g_x_w_m),
            self.global_dwc_l(g_x_w_l)
        ), dim=1)))
        x_w_attn = x_w_attn.view(b, c, 1, w_)

        x = x * x_h_attn * x_w_attn

        # Channel attention based on self attention
        # reduce calculations
        y = self.down_func(x)
        y = self.conv_d(y)
        _, _, h_, w_ = y.size()

        # normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
        y = self.norm(y)
        q = self.q(y)
        k = self.k(y)
        v = self.v(y)
        # (B, C, H, W) -> (B, head_num, head_dim, N)
        q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))
        k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))
        v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))

        # (B, head_num, head_dim, head_dim)
        attn = q @ k.transpose(-2, -1) * self.scaler
        attn = self.attn_drop(attn.softmax(dim=-1))
        # (B, head_num, head_dim, N)
        attn = attn @ v
        # (B, C, H_, W_)
        attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
        # (B, C, 1, 1)
        attn = attn.mean((2, 3), keepdim=True)
        attn = self.ca_gate(attn)
        return attn * x


if __name__ == '__main__':

    block = SCSA(
        dim=256,
        head_num=8,
    )

    input_tensor = torch.rand(1, 256, 32, 32)

    # Run a forward pass through the module
    output_tensor = block(input_tensor)

    # Print the input and output tensor sizes
    print("Input size:", input_tensor.size())
    print("Output size:", output_tensor.size())
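A quick arithmetic check on the constructor's dimensional constraints (my own illustration, not part of the committed file): `dim` must be divisible by 4 for the SMSA channel groups and, implicitly, by `head_num` for the PCSA heads. With the values from the __main__ demo above:

    dim, head_num = 256, 8
    group_chans = dim // 4        # 64 channels per depth-wise conv branch in SMSA
    head_dim = dim // head_num    # 32 channels per attention head in PCSA
    scaler = head_dim ** -0.5     # ~0.177, the attention scaling factor
    assert dim % 4 == 0 and dim % head_num == 0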
12  net.py
@@ -6,7 +6,8 @@ import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from einops import rearrange

-from componets.SMFA import SMFA
+from componets.SCSA import SCSA

 from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
 from componets.WTConvCV2 import WTConv2d

@@ -159,7 +160,8 @@ class BaseFeatureExtraction(nn.Module):

         self.WTConv2d = WTConv2d(dim, dim)
         self.norm1 = LayerNorm(dim, 'WithBias')
-        self.token_mixer = SMFA(dim=dim)
+        # self.token_mixer = SMFA(dim=dim)
+        self.token_mixer = SCSA(dim=dim, head_num=8)

         # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA and MLP-Mixers use MLPs; pooling is used here as a substitute
         self.norm2 = LayerNorm(dim, 'WithBias')
@@ -195,9 +197,11 @@ class BaseFeatureExtraction(nn.Module):
             )
             # unsqueeze(-1).unsqueeze(-1) appends two singleton dimensions to self.layer_scale_1, turning the 1-D per-channel tensor into a 3-D one so it can broadcast against the feature map in element-wise products, e.g. when rescaling the output of a convolution or attention branch.
             )
+            pol = self.poolmlp(self.norm2(x))
+
             x = wtConvX + self.drop_path(
                 self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-                * self.poolmlp(self.norm2(x)))
+                * pol)
         else:
             x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches cddfuse
             x = x + self.drop_path(self.poolmlp(self.norm2(x)))
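On the layer-scale comment in the hunk above, a minimal standalone sketch (my own example, with a hypothetical channel count of 64): adding two trailing singleton dimensions lets a per-channel scale broadcast over a (B, C, H, W) feature map.

    import torch

    layer_scale = torch.full((64,), 1e-2)                      # per-channel scale, shape (C,)
    feat = torch.randn(1, 64, 32, 32)                          # feature map, shape (B, C, H, W)
    scaled = layer_scale.unsqueeze(-1).unsqueeze(-1) * feat    # (C, 1, 1) broadcasts over (B, C, H, W)
    print(scaled.shape)                                        # torch.Size([1, 64, 32, 32])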
@@ -569,7 +573,7 @@ class Restormer_Decoder(nn.Module):
             nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                       stride=1, padding=1, bias=bias), )
         self.sigmoid = nn.Sigmoid()
-        self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim)
+        self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(in_channels=dim)

     def forward(self, inp_img, base_feature, detail_feature):
