diff --git a/componets/TIAM(CV).py b/componets/TIAM.py similarity index 100% rename from componets/TIAM(CV).py rename to componets/TIAM.py diff --git a/net.py b/net.py index 75df2bf..9e623f1 100644 --- a/net.py +++ b/net.py @@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange +from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared from componets.WTConvCV2 import WTConv2d @@ -538,6 +539,7 @@ class Restormer_Encoder(nn.Module): return base_feature, detail_feature, out_enc_level1 # 1 64 128 128 + class Restormer_Decoder(nn.Module): def __init__(self, inp_channels=1, @@ -561,8 +563,12 @@ class Restormer_Decoder(nn.Module): nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias), ) self.sigmoid = nn.Sigmoid() + self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim) def forward(self, inp_img, base_feature, detail_feature): + + base_feature, detail_feature =self.spatiotemporalAttentionFullNotWeightShared(base_feature, detail_feature) + out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1) out_enc_level0 = self.reduce_channel(out_enc_level0) out_enc_level1 = self.encoder_level2(out_enc_level0)