feat(model): 重构模型并添加新功能
- 重新组织了模型结构,增加了新的特征融合模块 - 添加了深度可分离卷积块和新的细节特征提取模块 - 更新了数据处理流程,使用了新的数据集路径 - 调整了训练参数,增加了训练轮次和学习率- 优化了损失函数,使用了Huber损失替代MSE损失
This commit is contained in:
parent
e1a339e04b
commit
fa8106838e
@ -1,2 +1,5 @@
|
|||||||
.idea/
|
.idea/
|
||||||
status.md
|
status.md
|
||||||
|
data/
|
||||||
|
test_img/
|
||||||
|
test_result/
|
||||||
|
@ -39,16 +39,17 @@ def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
|
|||||||
ratio = (limits[1] - limits[0]) / limits[1]
|
ratio = (limits[1] - limits[0]) / limits[1]
|
||||||
return ratio < fraction_threshold
|
return ratio < fraction_threshold
|
||||||
|
|
||||||
data_name="MSRS_train"
|
data_name="YYX_sar_opr_data"
|
||||||
img_size=128 #patch size
|
img_size=256 #patch size
|
||||||
stride=200 #patch stride
|
stride=200 #patch stride
|
||||||
|
|
||||||
IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
|
IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/YYX-OPT-SAR-main/SAR_1"))
|
||||||
VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))
|
VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/YYX-OPT-SAR-main/OPR_1"))
|
||||||
|
|
||||||
assert len(IR_files) == len(VIS_files)
|
assert len(IR_files) == len(VIS_files)
|
||||||
h5f = h5py.File(os.path.join('.\\data',
|
h5path= os.path.join('/home/star/whaiDir/PFCFuse/data/',
|
||||||
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
|
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5')
|
||||||
|
h5f = h5py.File(h5path,
|
||||||
'w')
|
'w')
|
||||||
h5_ir = h5f.create_group('ir_patchs')
|
h5_ir = h5f.create_group('ir_patchs')
|
||||||
h5_vis = h5f.create_group('vis_patchs')
|
h5_vis = h5f.create_group('vis_patchs')
|
||||||
@ -80,8 +81,7 @@ for i in tqdm(range(len(IR_files))):
|
|||||||
|
|
||||||
h5f.close()
|
h5f.close()
|
||||||
|
|
||||||
with h5py.File(os.path.join('data',
|
with h5py.File(h5path,"r") as f:
|
||||||
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
|
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
print(f[key], key, f[key].name)
|
print(f[key], key, f[key].name)
|
||||||
|
|
||||||
|
271
net.py
271
net.py
@ -148,8 +148,28 @@ class PoolMlp(nn.Module):
|
|||||||
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
# return x
|
# return x
|
||||||
|
class DetailFeatureFusion(nn.Module):
|
||||||
|
def __init__(self, num_layers=3):
|
||||||
|
super(DetailFeatureFusion, self).__init__()
|
||||||
|
INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
|
||||||
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
self.enhancement_module = WTConv2d(32, 32)
|
||||||
|
|
||||||
class BaseFeatureExtraction(nn.Module):
|
def forward(self, x): # 1 64 128 128
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
||||||
|
# 增强并添加残差连接
|
||||||
|
enhanced_z1 = self.enhancement_module(z1)
|
||||||
|
enhanced_z2 = self.enhancement_module(z2)
|
||||||
|
|
||||||
|
for layer in self.net:
|
||||||
|
z1, z2 = layer(z1, z2)
|
||||||
|
|
||||||
|
# 残差连接
|
||||||
|
z1 = z1 + enhanced_z1
|
||||||
|
z2 = z2 + enhanced_z2
|
||||||
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
|
class BaseFeatureFusion(nn.Module):
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
act_layer=nn.GELU,
|
act_layer=nn.GELU,
|
||||||
# norm_layer=nn.LayerNorm,
|
# norm_layer=nn.LayerNorm,
|
||||||
@ -158,7 +178,6 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.WTConv2d = WTConv2d(dim, dim)
|
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
# self.token_mixer = SMFA(dim=dim)
|
# self.token_mixer = SMFA(dim=dim)
|
||||||
self.token_mixer = SCSA(dim=dim,head_num=8)
|
self.token_mixer = SCSA(dim=dim,head_num=8)
|
||||||
@ -184,7 +203,7 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
def forward(self, x): # 1 64 128 128
|
def forward(self, x): # 1 64 128 128
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
# self.layer_scale_1(64,)
|
# self.layer_scale_1(64,)
|
||||||
wtConvX = self.WTConv2d(x)
|
# wtConvX = self.WTConv2d(x)
|
||||||
|
|
||||||
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
normal = self.norm1(x) # 1 64 128 128
|
normal = self.norm1(x) # 1 64 128 128
|
||||||
@ -192,16 +211,124 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
x = (x +
|
x = (x +
|
||||||
self.drop_path(
|
self.drop_path(tmp1 * token_mix)
|
||||||
tmp1 * token_mix
|
|
||||||
)
|
|
||||||
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
)
|
)
|
||||||
pol = self.poolmlp(self.norm2(x))
|
# pol = self.poolmlp(self.norm2(x))
|
||||||
|
#
|
||||||
|
# x = x + self.drop_path(
|
||||||
|
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
# * pol)
|
||||||
|
else:
|
||||||
|
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
x = wtConvX + self.drop_path(
|
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
* pol)
|
class BaseFeatureExtraction(nn.Module):
|
||||||
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
# norm_layer=nn.LayerNorm,
|
||||||
|
drop=0., drop_path=0.,
|
||||||
|
use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
|
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
|
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
# The following two techniques are useful to train deep PoolFormers.
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||||
|
else nn.Identity()
|
||||||
|
self.use_layer_scale = use_layer_scale
|
||||||
|
|
||||||
|
if use_layer_scale:
|
||||||
|
self.layer_scale_1 = nn.Parameter(
|
||||||
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
|
self.layer_scale_2 = nn.Parameter(
|
||||||
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
|
def forward(self, x): # 1 64 128 128
|
||||||
|
if self.use_layer_scale:
|
||||||
|
# self.layer_scale_1(64,)
|
||||||
|
# wtConvX = self.WTConv2d(x)
|
||||||
|
|
||||||
|
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
|
normal = self.norm1(x) # 1 64 128 128
|
||||||
|
token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
|
|
||||||
|
x = (x +
|
||||||
|
self.drop_path(tmp1 * token_mix)
|
||||||
|
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
|
)
|
||||||
|
# pol = self.poolmlp(self.norm2(x))
|
||||||
|
#
|
||||||
|
# x = x + self.drop_path(
|
||||||
|
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
# * pol)
|
||||||
|
else:
|
||||||
|
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
class BaseFeatureExtractionSAR(nn.Module):
|
||||||
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
# norm_layer=nn.LayerNorm,
|
||||||
|
drop=0., drop_path=0.,
|
||||||
|
use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
|
# self.token_mixer = SMFA(dim=dim)
|
||||||
|
self.token_mixer = SCSA(dim=dim,head_num=8)
|
||||||
|
|
||||||
|
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
|
# self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
|
# act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
# The following two techniques are useful to train deep PoolFormers.
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||||
|
else nn.Identity()
|
||||||
|
self.use_layer_scale = use_layer_scale
|
||||||
|
|
||||||
|
if use_layer_scale:
|
||||||
|
self.layer_scale_1 = nn.Parameter(
|
||||||
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
|
self.layer_scale_2 = nn.Parameter(
|
||||||
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
|
def forward(self, x): # 1 64 128 128
|
||||||
|
if self.use_layer_scale:
|
||||||
|
# self.layer_scale_1(64,)
|
||||||
|
# wtConvX = self.WTConv2d(x)
|
||||||
|
|
||||||
|
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
|
normal = self.norm1(x) # 1 64 128 128
|
||||||
|
token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
|
|
||||||
|
|
||||||
|
x = (x +
|
||||||
|
self.drop_path(tmp1 * token_mix)
|
||||||
|
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
|
)
|
||||||
|
# pol = self.poolmlp(self.norm2(x))
|
||||||
|
#
|
||||||
|
# x = x + self.drop_path(
|
||||||
|
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
# * pol)
|
||||||
else:
|
else:
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
@ -228,12 +355,38 @@ class InvertedResidualBlock(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bottleneckBlock(x)
|
return self.bottleneckBlock(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthwiseSeparableConvBlock(nn.Module):
|
||||||
|
def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
|
||||||
|
super(DepthwiseSeparableConvBlock, self).__init__()
|
||||||
|
self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
|
||||||
|
self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
|
||||||
|
self.bn = nn.BatchNorm2d(oup)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.depthwise(x)
|
||||||
|
x = self.pointwise(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class DetailNode(nn.Module):
|
class DetailNode(nn.Module):
|
||||||
|
|
||||||
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
|
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
|
||||||
def __init__(self):
|
def __init__(self,useBlock=0):
|
||||||
super(DetailNode, self).__init__()
|
super(DetailNode, self).__init__()
|
||||||
|
|
||||||
|
if useBlock==0:
|
||||||
|
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
||||||
|
self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
||||||
|
self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
||||||
|
elif useBlock==1:
|
||||||
|
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
|
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
|
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
|
else:
|
||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
@ -251,26 +404,29 @@ class DetailNode(nn.Module):
|
|||||||
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
|
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
|
||||||
return z1, z2
|
return z1, z2
|
||||||
|
|
||||||
|
class DetailFeatureExtractionSAR(nn.Module):
|
||||||
|
def __init__(self, num_layers=3):
|
||||||
|
super(DetailFeatureExtractionSAR, self).__init__()
|
||||||
|
# useBlock = 1表示使用 invresblock
|
||||||
|
INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
|
||||||
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
||||||
|
for layer in self.net:
|
||||||
|
z1, z2 = layer(z1, z2)
|
||||||
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
class DetailFeatureExtraction(nn.Module):
|
class DetailFeatureExtraction(nn.Module):
|
||||||
def __init__(self, num_layers=3):
|
def __init__(self, num_layers=3):
|
||||||
super(DetailFeatureExtraction, self).__init__()
|
super(DetailFeatureExtraction, self).__init__()
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
self.enhancement_module = WTConv2d(32, 32)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
|
||||||
# 增强并添加残差连接
|
|
||||||
enhanced_z1 = self.enhancement_module(z1)
|
|
||||||
enhanced_z2 = self.enhancement_module(z2)
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
||||||
for layer in self.net:
|
for layer in self.net:
|
||||||
z1, z2 = layer(z1, z2)
|
z1, z2 = layer(z1, z2)
|
||||||
|
|
||||||
# 残差连接
|
|
||||||
z1 = z1 + enhanced_z1
|
|
||||||
z2 = z2 + enhanced_z2
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -435,78 +591,9 @@ class OverlapPatchEmbed(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BaseFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
|
||||||
act_layer=nn.GELU,
|
|
||||||
# norm_layer=nn.LayerNorm,
|
|
||||||
drop=0., drop_path=0.,
|
|
||||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.WTConv2d = WTConv2d(dim, dim)
|
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
|
||||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
|
||||||
act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
# The following two techniques are useful to train deep PoolFormers.
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
|
||||||
else nn.Identity()
|
|
||||||
self.use_layer_scale = use_layer_scale
|
|
||||||
|
|
||||||
if use_layer_scale:
|
|
||||||
self.layer_scale_1 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
self.layer_scale_2 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
if self.use_layer_scale:
|
|
||||||
# self.layer_scale_1(64,)
|
|
||||||
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
|
||||||
normal = self.norm1(x) # 1 64 128 128
|
|
||||||
token_mix = self.token_mixer(normal) # 1 64 128 128
|
|
||||||
|
|
||||||
x = self.WTConv2d(x)
|
|
||||||
|
|
||||||
x = (x +
|
|
||||||
self.drop_path(
|
|
||||||
tmp1 * token_mix
|
|
||||||
)
|
|
||||||
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
|
||||||
)
|
|
||||||
x = x + self.drop_path(
|
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
* self.poolmlp(self.norm2(x)))
|
|
||||||
else:
|
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DetailFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, num_layers=3):
|
|
||||||
super(DetailFeatureExtractionSAR, self).__init__()
|
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
|
||||||
self.net = nn.Sequential(*INNmodules)
|
|
||||||
self.enhancement_module = WTConv2d(32, 32)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
|
||||||
# 增强并添加残差连接
|
|
||||||
enhanced_z1 = self.enhancement_module(z1)
|
|
||||||
enhanced_z2 = self.enhancement_module(z2)
|
|
||||||
# 残差连接
|
|
||||||
z1 = z1 + enhanced_z1
|
|
||||||
z2 = z2 + enhanced_z2
|
|
||||||
for layer in self.net:
|
|
||||||
z1, z2 = layer(z1, z2)
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,9 +17,9 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|||||||
|
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-08-16-20.pth"
|
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-11-20-36.pth"
|
||||||
|
|
||||||
for dataset_name in ["TNO","RoadScene"]:
|
for dataset_name in ["sar"]:
|
||||||
print("\n"*2+"="*80)
|
print("\n"*2+"="*80)
|
||||||
model_name="PFCFuse "
|
model_name="PFCFuse "
|
||||||
print("The test result of "+dataset_name+' :')
|
print("The test result of "+dataset_name+' :')
|
||||||
|
26
train.py
26
train.py
@ -6,7 +6,8 @@ Import packages
|
|||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
'''
|
'''
|
||||||
|
|
||||||
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
|
||||||
|
DetailFeatureFusion
|
||||||
from utils.dataset import H5Dataset
|
from utils.dataset import H5Dataset
|
||||||
import os
|
import os
|
||||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||||
@ -31,11 +32,11 @@ Configure our network
|
|||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
criteria_fusion = Fusionloss()
|
criteria_fusion = Fusionloss()
|
||||||
model_str = 'PFCFuse'
|
model_str = 'WhaiFuse'
|
||||||
|
|
||||||
# . Set the hyper-parameters for training
|
# . Set the hyper-parameters for training
|
||||||
num_epochs = 60 # total epoch
|
num_epochs = 120 # total epoch
|
||||||
epoch_gap = 40 # epoches of Phase I
|
epoch_gap = 80 # epoches of Phase I
|
||||||
|
|
||||||
lr = 1e-4
|
lr = 1e-4
|
||||||
weight_decay = 0
|
weight_decay = 0
|
||||||
@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||||||
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||||
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
|
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
|
||||||
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
|
||||||
|
|
||||||
# optimizer, scheduler and loss function
|
# optimizer, scheduler and loss function
|
||||||
optimizer1 = torch.optim.Adam(
|
optimizer1 = torch.optim.Adam(
|
||||||
@ -109,7 +110,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
|
|||||||
HuberLoss = nn.HuberLoss()
|
HuberLoss = nn.HuberLoss()
|
||||||
|
|
||||||
# data loader
|
# data loader
|
||||||
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"),
|
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_128_stride_200.h5"),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=0)
|
num_workers=0)
|
||||||
@ -156,6 +157,7 @@ for epoch in range(num_epochs):
|
|||||||
cc_loss_B = cc(feature_V_B, feature_I_B)
|
cc_loss_B = cc(feature_V_B, feature_I_B)
|
||||||
cc_loss_D = cc(feature_V_D, feature_I_D)
|
cc_loss_D = cc(feature_V_D, feature_I_D)
|
||||||
|
|
||||||
|
# HuberLoss 对比CDDFUSE的MSELoss
|
||||||
mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
|
mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
|
||||||
mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)
|
mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)
|
||||||
|
|
||||||
@ -169,6 +171,16 @@ for epoch in range(num_epochs):
|
|||||||
loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
|
loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
|
||||||
# print("loss_decomp", loss_decomp)
|
# print("loss_decomp", loss_decomp)
|
||||||
|
|
||||||
|
"""
|
||||||
|
-Huber 损失函数是一种鲁棒的回归损失函数,
|
||||||
|
结合了均方误差(MSE)和绝对误差(MAE)的优点。
|
||||||
|
其设计目的是为了减少在数据中出现异常值(outliers)时对模型训练的影响。
|
||||||
|
在传统的 MSE 损失函数中,对大的误差给予较大的惩罚,
|
||||||
|
而 MAE 则对所有误差给予同等的线性惩罚。
|
||||||
|
然而,MSE 在面对异常值时可能导致模型的不稳定,
|
||||||
|
而 MAE 则可能导致梯度消失的问题。
|
||||||
|
a-Huber 损失函数通过一个阈值 δ 来自适应地选择惩罚方式:
|
||||||
|
"""
|
||||||
|
|
||||||
loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
|
loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
|
||||||
loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)
|
loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)
|
||||||
|
Loading…
Reference in New Issue
Block a user