feat(net): 替换 token_mixer 为 SCSA

- 移除 Pooling 层,使用 SCSA(Spectral Correlation and Spatial Attention)层作为 token_mixer
- 更新模型架构,以适应新的 SCSA 层
- 修改测试脚本,针对特定数据集进行测试
This commit is contained in:
zjut 2024-11-15 17:32:41 +08:00
parent 4f805c2449
commit 97501d082b
2 changed files with 6 additions and 4 deletions

4
net.py
View File

@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange from einops import rearrange
from componets.SCSA import SCSA
def drop_path(x, drop_prob: float = 0., training: bool = False): def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training: if drop_prob == 0. or not training:
@ -101,7 +102,8 @@ class BaseFeatureExtractionSAR(nn.Module):
super().__init__() super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.token_mixer = SCSA(dim=dim, head_num=8)
self.norm2 = LayerNorm(dim, 'WithBias') self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,

View File

@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-08-16-20.pth" ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-11-11.pth"
for dataset_name in ["TNO","RoadScene"]: for dataset_name in ["sar"]:
print("\n"*2+"="*80) print("\n"*2+"="*80)
model_name="PFCFuse " model_name="whai 修改SCSA分支 "
print("The test result of "+dataset_name+' :') print("The test result of "+dataset_name+' :')
test_folder = os.path.join('test_img', dataset_name) test_folder = os.path.join('test_img', dataset_name)
test_out_folder=os.path.join('test_result',current_time,dataset_name) test_out_folder=os.path.join('test_result',current_time,dataset_name)