feat(net): 添加时空注意力机制 - 在 net.py 中引入 SpatiotemporalAttentionFullNotWeightShared 模块

- 在 Restormer_Decoder 类中添加时空注意力机制处理基础特征和细节特征
This commit is contained in:
whaifree 2024-10-26 19:18:01 +08:00
parent f4b3a933bf
commit 85dc7a92ed
2 changed files with 6 additions and 0 deletions

6
net.py
View File

@@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
from componets.WTConvCV2 import WTConv2d
@@ -538,6 +539,7 @@ class Restormer_Encoder(nn.Module):
        return base_feature, detail_feature, out_enc_level1  # 1 64 128 128

class Restormer_Decoder(nn.Module):
    def __init__(self,
                 inp_channels=1,
@@ -561,8 +563,12 @@ class Restormer_Decoder(nn.Module):
            nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias), )
        self.sigmoid = nn.Sigmoid()
        self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim)

    def forward(self, inp_img, base_feature, detail_feature):
        base_feature, detail_feature = self.spatiotemporalAttentionFullNotWeightShared(base_feature, detail_feature)
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)