feat(net): 替换 token_mixer 为 SCSA
- 移除 Pooling 层,使用 SCSA(Spectral Correlation and Spatial Attention)层作为 token_mixer - 更新模型架构,以适应新的 SCSA 层 - 修改测试脚本,针对特定数据集进行测试
This commit is contained in:
parent
4f805c2449
commit
97501d082b
4
net.py
4
net.py
@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
from einops import rearrange
|
||||
|
||||
from componets.SCSA import SCSA
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
if drop_prob == 0. or not training:
|
||||
@ -101,7 +102,8 @@ class BaseFeatureExtractionSAR(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||
self.token_mixer = SCSA(dim=dim, head_num=8)
|
||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||
|
@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-08-16-20.pth"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-11-11.pth"
|
||||
|
||||
for dataset_name in ["TNO","RoadScene"]:
|
||||
for dataset_name in ["sar"]:
|
||||
print("\n"*2+"="*80)
|
||||
model_name="PFCFuse "
|
||||
model_name="whai 修改SCSA分支 "
|
||||
print("The test result of "+dataset_name+' :')
|
||||
test_folder = os.path.join('test_img', dataset_name)
|
||||
test_out_folder=os.path.join('test_result',current_time,dataset_name)
|
||||
|
Loading…
Reference in New Issue
Block a user