From 85dc7a92edb8bb7ffdc3096d2201bfd48bab09a8 Mon Sep 17 00:00:00 2001
From: whaifree
Date: Sat, 26 Oct 2024 19:18:01 +0800
Subject: [PATCH] feat(net): add spatiotemporal attention mechanism
 - import the SpatiotemporalAttentionFullNotWeightShared module in net.py
 - apply spatiotemporal attention to the base and detail features in the
   Restormer_Decoder class
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 componets/{TIAM(CV).py => TIAM.py} | 0
 net.py                             | 6 ++++++
 2 files changed, 6 insertions(+)
 rename componets/{TIAM(CV).py => TIAM.py} (100%)

diff --git a/componets/TIAM(CV).py b/componets/TIAM.py
similarity index 100%
rename from componets/TIAM(CV).py
rename to componets/TIAM.py
diff --git a/net.py b/net.py
index 75df2bf..9e623f1 100644
--- a/net.py
+++ b/net.py
@@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from einops import rearrange
 
+from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
 from componets.WTConvCV2 import WTConv2d
 
 
@@ -538,6 +539,7 @@ class Restormer_Encoder(nn.Module):
 
         return base_feature, detail_feature, out_enc_level1  # 1 64 128 128
 
+
 class Restormer_Decoder(nn.Module):
     def __init__(self,
                  inp_channels=1,
@@ -561,8 +563,12 @@ class Restormer_Decoder(nn.Module):
             nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                       stride=1, padding=1, bias=bias), )
         self.sigmoid = nn.Sigmoid()
+        self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim)
 
     def forward(self, inp_img, base_feature, detail_feature):
+
+        base_feature, detail_feature = self.spatiotemporalAttentionFullNotWeightShared(base_feature, detail_feature)
+
         out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
         out_enc_level0 = self.reduce_channel(out_enc_level0)
         out_enc_level1 = self.encoder_level2(out_enc_level0)
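
Note (not part of the patch): a minimal, runnable sketch of the data flow this change introduces, i.e. routing the encoder's base and detail features through the attention step before they are concatenated in Restormer_Decoder.forward. Only the call interface (takes and returns a base/detail pair of `dim`-channel maps) is taken from the diff above; the internals of SpatiotemporalAttentionFullNotWeightShared below are a hypothetical stand-in, not the actual code in componets/TIAM.py.

    import torch
    import torch.nn as nn


    class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
        """Stand-in with the interface implied by the patch: refines a
        (base, detail) feature pair without sharing weights between the
        two branches. The real implementation lives in componets/TIAM.py."""

        def __init__(self, dim=64):
            super().__init__()
            # separate (not weight-shared) gating convolutions per branch
            self.base_gate = nn.Sequential(nn.Conv2d(dim, dim, 1), nn.Sigmoid())
            self.detail_gate = nn.Sequential(nn.Conv2d(dim, dim, 1), nn.Sigmoid())

        def forward(self, base, detail):
            # cross-gating placeholder: each branch is modulated by the other
            return base * self.detail_gate(detail), detail * self.base_gate(base)


    if __name__ == "__main__":
        dim = 64
        attn = SpatiotemporalAttentionFullNotWeightShared(dim=dim)
        base = torch.randn(1, dim, 128, 128)    # base feature from Restormer_Encoder
        detail = torch.randn(1, dim, 128, 128)  # detail feature from Restormer_Encoder

        # the step the patch inserts at the top of Restormer_Decoder.forward
        base, detail = attn(base, detail)

        # afterwards the decoder concatenates and reduces channels as before
        fused = torch.cat((base, detail), dim=1)
        print(fused.shape)  # torch.Size([1, 128, 128, 128])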