From 97501d082bcf427c37c0025e2718acc1dd8e3048 Mon Sep 17 00:00:00 2001 From: zjut Date: Fri, 15 Nov 2024 17:32:41 +0800 Subject: [PATCH] =?UTF-8?q?feat(net):=20=E6=9B=BF=E6=8D=A2=20token=5Fmixer?= =?UTF-8?q?=20=E4=B8=BA=20SCSA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 Pooling 层,使用 SCSA(Spectral Correlation and Spatial Attention)层作为 token_mixer - 更新模型架构,以适应新的 SCSA 层 - 修改测试脚本,针对特定数据集进行测试 --- net.py | 4 +++- test_IVF.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/net.py b/net.py index 09bee70..7218ce3 100644 --- a/net.py +++ b/net.py @@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange +from componets.SCSA import SCSA def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: @@ -101,7 +102,8 @@ class BaseFeatureExtractionSAR(nn.Module): super().__init__() self.norm1 = LayerNorm(dim, 'WithBias') - self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.token_mixer = SCSA(dim=dim, head_num=8) self.norm2 = LayerNorm(dim, 'WithBias') mlp_hidden_dim = int(dim * mlp_ratio) self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, diff --git a/test_IVF.py b/test_IVF.py index c402ae7..f731c7f 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-08-16-20.pth" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-11-11.pth" -for dataset_name in ["TNO","RoadScene"]: +for dataset_name in ["sar"]: print("\n"*2+"="*80) - model_name="PFCFuse " + model_name="whai 修改SCSA分支 " print("The test result of "+dataset_name+' :') test_folder = os.path.join('test_img', dataset_name) test_out_folder=os.path.join('test_result',current_time,dataset_name)