feat(net): add spatiotemporal attention mechanism

- Introduce the SpatiotemporalAttentionFullNotWeightShared module in net.py
- Apply spatiotemporal attention in the Restormer_Decoder class to process the base and detail features
parent f4b3a933bf
commit 85dc7a92ed

net.py: 6 additions
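The diff below imports SpatiotemporalAttentionFullNotWeightShared from componets/TIAM.py, which is not part of this commit. As orientation only, here is a minimal sketch of the two-stream interface the decoder relies on: cross-attention between two feature maps where each stream keeps its own (non-shared) projections. The class name matches the import, but everything inside is an assumption for illustration, not the repository's code.

# Hypothetical sketch -- componets/TIAM.py is not shown in this diff, so the
# internals below are assumptions made for illustration only.
import torch
import torch.nn as nn

class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
    """Cross-attention between two feature maps; each stream has its own
    (non-shared) query/key/value projections, hence "NotWeightShared"."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        # Separate 1x1 projections per stream
        self.q1, self.k1, self.v1 = (nn.Conv2d(dim, dim, 1) for _ in range(3))
        self.q2, self.k2, self.v2 = (nn.Conv2d(dim, dim, 1) for _ in range(3))

    def _attend(self, q, k, v):
        b, c, h, w = q.shape
        q = q.flatten(2).transpose(1, 2)                  # (b, hw, c)
        k = k.flatten(2)                                  # (b, c, hw)
        v = v.flatten(2).transpose(1, 2)                  # (b, hw, c)
        attn = torch.softmax(q @ k * self.scale, dim=-1)  # (b, hw, hw)
        return (attn @ v).transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x1, x2):
        # Each stream queries the other; the residual keeps the original signal.
        y1 = x1 + self._attend(self.q1(x1), self.k2(x2), self.v2(x2))
        y2 = x2 + self._attend(self.q2(x2), self.k1(x1), self.v1(x1))
        return y1, y2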
@@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from einops import rearrange
 
+from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
 from componets.WTConvCV2 import WTConv2d
 
 
@@ -538,6 +539,7 @@ class Restormer_Encoder(nn.Module):
         return base_feature, detail_feature, out_enc_level1  # 1 64 128 128
 
+
 
 class Restormer_Decoder(nn.Module):
     def __init__(self,
                  inp_channels=1,
@@ -561,8 +563,12 @@ class Restormer_Decoder(nn.Module):
             nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                       stride=1, padding=1, bias=bias), )
         self.sigmoid = nn.Sigmoid()
+
+        self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim)
 
     def forward(self, inp_img, base_feature, detail_feature):
+
+        base_feature, detail_feature = self.spatiotemporalAttentionFullNotWeightShared(base_feature, detail_feature)
         out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
         out_enc_level0 = self.reduce_channel(out_enc_level0)
         out_enc_level1 = self.encoder_level2(out_enc_level0)
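The forward change above refines both streams in place before the torch.cat fusion, so downstream shapes are unchanged. A quick sanity check of that contract, reusing the sketch class from above (hypothetical, not part of the commit; the real TIAM implementation may pool or factorize the attention rather than forming the full map):

# Hypothetical shape check reusing the sketch class above. The encoder comment
# notes 1 x 64 x 128 x 128 features; a smaller spatial size is used here because
# the full hw x hw attention map in the sketch is memory-heavy at 128 x 128.
base = torch.randn(1, 64, 32, 32)
detail = torch.randn(1, 64, 32, 32)
attn = SpatiotemporalAttentionFullNotWeightShared(dim=64)
base, detail = attn(base, detail)
print(base.shape, detail.shape)  # both torch.Size([1, 64, 32, 32])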