From fa8106838eb424930abe5a1c7f3250f7f0c571a3 Mon Sep 17 00:00:00 2001
From: zjut
Date: Tue, 12 Nov 2024 10:37:56 +0800
Subject: [PATCH] feat(model): refactor the model and add new features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Reorganized the model structure and added new feature fusion modules
- Added a depthwise separable convolution block and new detail feature
  extraction modules
- Updated the data processing pipeline to use the new dataset paths
- Adjusted the training parameters, increasing the number of epochs and
  the learning rate
- Optimized the loss function, replacing MSE loss with Huber loss
---
 .gitignore        |   3 +
 dataprocessing.py |  40 +++----
 net.py            | 277 ++++++++++++++++++++++++++++++----------------
 test_IVF.py       |   4 +-
 train.py          |  26 +++--
 5 files changed, 226 insertions(+), 124 deletions(-)

diff --git a/.gitignore b/.gitignore
index edf7231..0a1608b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
 .idea/
 status.md
+data/
+test_img/
+test_result/
diff --git a/dataprocessing.py b/dataprocessing.py
index 83bb220..29e5599 100644
--- a/dataprocessing.py
+++ b/dataprocessing.py
@@ -12,7 +12,7 @@ def get_img_file(file_name):
             if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
                 imagelist.append(os.path.join(parent, filename))
     return imagelist
-
+
 def rgb2y(img):
     y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
     return y
@@ -39,16 +39,17 @@ def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
     ratio = (limits[1] - limits[0]) / limits[1]
     return ratio < fraction_threshold

-data_name="MSRS_train"
-img_size=128 #patch size
+data_name="YYX_sar_opr_data"
+img_size=256 #patch size
 stride=200 #patch stride

-IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
-VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/YYX-OPT-SAR-main/SAR_1"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/YYX-OPT-SAR-main/OPR_1"))

 assert len(IR_files) == len(VIS_files)
-h5f = h5py.File(os.path.join('.\\data',
-                             data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+h5path = os.path.join('/home/star/whaiDir/PFCFuse/data/',
+                      data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5')
+h5f = h5py.File(h5path,
                'w')
 h5_ir = h5f.create_group('ir_patchs')
 h5_vis = h5f.create_group('vis_patchs')
@@ -57,11 +58,11 @@ for i in tqdm(range(len(IR_files))):
     I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
     I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
     I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
-
-    # crop
+
+    # crop
     I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
     I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # (3, 256, 256, 12)
-
+
     for ii in range(I_IR_Patch_Group.shape[-1]):
         bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
         bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
@@ -72,22 +73,21 @@ for i in tqdm(range(len(IR_files))):
             avl_IR=avl_IR[None,...]
             avl_VIS=avl_VIS[None,...]
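
A note on the cropping step above: Im2Patch slides an img_size window with step
stride over each (C, H, W) image and stacks the crops along a trailing patch
axis (hence the "(3, 256, 256, 12)" comment), and any patch whose IR or visible
half fails the is_low_contrast test is discarded. The following is a minimal
sketch of that windowing logic for reference only; im2patch_sketch is a
hypothetical name, and the repository's actual Im2Patch may handle edges and
striding differently:

import numpy as np

def im2patch_sketch(img, win, stride):
    # img: (C, H, W) float32 array; returns (C, win, win, num_patches)
    c, h, w = img.shape
    patches = []
    for top in range(0, h - win + 1, stride):
        for left in range(0, w - win + 1, stride):
            patches.append(img[:, top:top + win, left:left + win])
    # stack along a new trailing axis, matching I_IR_Patch_Group's layout
    return np.stack(patches, axis=-1)
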
-            h5_ir.create_dataset(str(train_num), data=avl_IR,
+            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                  dtype=avl_IR.dtype, shape=avl_IR.shape)
-            h5_vis.create_dataset(str(train_num), data=avl_VIS,
+            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
-            train_num += 1
+            train_num += 1
 h5f.close()

-with h5py.File(os.path.join('data',
-                            data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
+with h5py.File(h5path,"r") as f:
     for key in f.keys():
-        print(f[key], key, f[key].name)
-
-
-
+        print(f[key], key, f[key].name)
+
+
+
+
-
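
The new DetailFeatureFusion above keeps a WTConv2d-enhanced copy of each
channel half and adds it back after the DetailNode chain, so the enhancement
acts as a long residual connection around the INN layers. Below is a
self-contained sketch of that wiring; the depthwise 3x3 conv standing in for
the repository's WTConv2d and the 1x1 conv standing in for the DetailNode
chain are both illustrative assumptions:

import torch
import torch.nn as nn

class ResidualEnhanceSketch(nn.Module):
    def __init__(self, half_dim=32):
        super().__init__()
        # stand-in for WTConv2d(32, 32); one module shared by both halves, as above
        self.enhance = nn.Conv2d(half_dim, half_dim, 3, padding=1, groups=half_dim)
        # stand-in for the DetailNode chain; any (z1, z2) -> (z1, z2) update fits here
        self.inner = nn.Conv2d(half_dim, half_dim, 1)

    def forward(self, x):
        z1, z2 = x.chunk(2, dim=1)                   # split channels into two halves
        e1, e2 = self.enhance(z1), self.enhance(z2)  # enhanced copies, saved for later
        z1, z2 = self.inner(z1), self.inner(z2)      # inner transform of both halves
        return torch.cat((z1 + e1, z2 + e2), dim=1)  # residual add, then re-concat

x = torch.randn(1, 64, 128, 128)
print(ResidualEnhanceSketch()(x).shape)  # torch.Size([1, 64, 128, 128])
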
@@ -192,16 +211,124 @@ class BaseFeatureExtraction(nn.Module):

             x = (x +
-                 self.drop_path(
-                     tmp1 * token_mix
-                 )
+                 self.drop_path(tmp1 * token_mix)
+                 # layer_scale_1 is a 1-D tensor; the two unsqueeze(-1) calls give it shape (dim, 1, 1) so it broadcasts over the (C, H, W) feature map as a learnable per-channel scale on the token-mixer output
                  )
-        pol = self.poolmlp(self.norm2(x))
+            # pol = self.poolmlp(self.norm2(x))
+            #
+            # x = x + self.drop_path(
+            #     self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+            #     * pol)
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # match CDDFuse
+            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
+        return x

-        x = wtConvX + self.drop_path(
-            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-            * pol)
+
+
+class BaseFeatureExtraction(nn.Module):
+    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
+                 act_layer=nn.GELU,
+                 # norm_layer=nn.LayerNorm,
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
+
+        super().__init__()
+
+
+        self.norm1 = LayerNorm(dim, 'WithBias')
+        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA and MLP models use MLPs as the token mixer; pooling stands in for it here
+        self.norm2 = LayerNorm(dim, 'WithBias')
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                               act_layer=act_layer, drop=drop)
+
+        # The following two techniques are useful to train deep PoolFormers.
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
+            else nn.Identity()
+        self.use_layer_scale = use_layer_scale
+
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+            self.layer_scale_2 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+    def forward(self, x):  # 1 64 128 128
+        if self.use_layer_scale:
+            # self.layer_scale_1: shape (64,)
+            # wtConvX = self.WTConv2d(x)
+
+            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
+            normal = self.norm1(x)  # 1 64 128 128
+            token_mix = self.token_mixer(normal)  # 1 64 128 128
+
+            x = (x +
+                 self.drop_path(tmp1 * token_mix)
+                 # broadcast the (dim,) layer scale to (dim, 1, 1) for per-channel scaling
+                 )
+            # pol = self.poolmlp(self.norm2(x))
+            #
+            # x = x + self.drop_path(
+            #     self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+            #     * pol)
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # match CDDFuse
+            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
+        return x
+
+class BaseFeatureExtractionSAR(nn.Module):
+    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
+                 act_layer=nn.GELU,
+                 # norm_layer=nn.LayerNorm,
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
+
+        super().__init__()
+
+
+        self.norm1 = LayerNorm(dim, 'WithBias')
+        # self.token_mixer = SMFA(dim=dim)
+        self.token_mixer = SCSA(dim=dim, head_num=8)
+
+        # self.token_mixer = Pooling(kernel_size=pool_size)  # pooling as the token mixer
+        # self.norm2 = LayerNorm(dim, 'WithBias')
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        # self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
+        #                        act_layer=act_layer, drop=drop)
+
+        # The following two techniques are useful to train deep PoolFormers.
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
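
The broadcasting comment that recurs in these classes can be checked in
isolation: layer_scale_1 has shape (dim,), and the two unsqueeze(-1) calls
reshape it to (dim, 1, 1) so that multiplying a (B, dim, H, W) feature map
scales each channel by its own learnable factor. A quick standalone check:

import torch

dim = 64
layer_scale = torch.ones(dim) * 1e-5       # (64,), as initialized above
token_mix = torch.randn(1, dim, 128, 128)  # (1, 64, 128, 128)
scaled = layer_scale.unsqueeze(-1).unsqueeze(-1) * token_mix  # (64, 1, 1) broadcasts
print(scaled.shape)  # torch.Size([1, 64, 128, 128]); each channel scaled independently
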
+            else nn.Identity()
+        self.use_layer_scale = use_layer_scale
+
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+            self.layer_scale_2 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+    def forward(self, x):  # 1 64 128 128
+        if self.use_layer_scale:
+            # self.layer_scale_1: shape (64,)
+            # wtConvX = self.WTConv2d(x)
+
+            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
+            normal = self.norm1(x)  # 1 64 128 128
+            token_mix = self.token_mixer(normal)  # 1 64 128 128
+
+
+            x = (x +
+                 self.drop_path(tmp1 * token_mix)
+                 # broadcast the (dim,) layer scale to (dim, 1, 1) for per-channel scaling
+                 )
+            # pol = self.poolmlp(self.norm2(x))
+            #
+            # x = x + self.drop_path(
+            #     self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+            #     * pol)
         else:
             x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # match CDDFuse
             x = x + self.drop_path(self.poolmlp(self.norm2(x)))
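
The DepthwiseSeparableConvBlock added below factorizes a k x k convolution into
a per-channel (depthwise, groups=inp) k x k convolution followed by a 1 x 1
pointwise mix, cutting the weight count from k*k*Cin*Cout to
k*k*Cin + Cin*Cout. For the 32-to-32, 3x3 blocks used in DetailNode the saving
can be verified directly (count_params is a hypothetical helper):

import torch.nn as nn

def count_params(m):
    return sum(p.numel() for p in m.parameters())

standard = nn.Conv2d(32, 32, 3, padding=1, bias=False)
separable = nn.Sequential(
    nn.Conv2d(32, 32, 3, padding=1, groups=32, bias=False),  # depthwise: 3*3*32 weights
    nn.Conv2d(32, 32, 1, bias=False),                        # pointwise: 32*32 weights
)
print(count_params(standard))   # 9216 = 3*3*32*32
print(count_params(separable))  # 1312 = 3*3*32 + 32*32
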
@@ -228,15 +355,41 @@ class InvertedResidualBlock(nn.Module):
     def forward(self, x):
         return self.bottleneckBlock(x)

+
+class DepthwiseSeparableConvBlock(nn.Module):
+    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
+        super(DepthwiseSeparableConvBlock, self).__init__()
+        self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
+        self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
+        self.bn = nn.BatchNorm2d(oup)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.depthwise(x)   # per-channel spatial filtering
+        x = self.pointwise(x)   # 1x1 cross-channel mixing
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
 class DetailNode(nn.Module):
     # '
-    def __init__(self):
+    def __init__(self, useBlock=0):
         super(DetailNode, self).__init__()
-        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        if useBlock == 0:
+            self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
+            self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
+            self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
+        elif useBlock == 1:
+            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        else:
+            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
         self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
                                     stride=1, padding=0, bias=True)
@@ -251,26 +404,29 @@ class DetailNode(nn.Module):
         z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
         return z1, z2

+class DetailFeatureExtractionSAR(nn.Module):
+    def __init__(self, num_layers=3):
+        super(DetailFeatureExtractionSAR, self).__init__()
+        # useBlock=1 selects the InvertedResidualBlock variant
+        INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
+        self.net = nn.Sequential(*INNmodules)
+
+    def forward(self, x):
+        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
+        for layer in self.net:
+            z1, z2 = layer(z1, z2)
+        return torch.cat((z1, z2), dim=1)

 class DetailFeatureExtraction(nn.Module):
     def __init__(self, num_layers=3):
         super(DetailFeatureExtraction, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
+        INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
-        self.enhancement_module = WTConv2d(32, 32)
-
-    def forward(self, x):  # 1 64 128 128
-        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # enhance both halves and keep the copies for residual connections
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)

+    def forward(self, x):
+        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
         for layer in self.net:
             z1, z2 = layer(z1, z2)
-
-        # residual connections
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
         return torch.cat((z1, z2), dim=1)

 # =============================================================================
@@ -435,78 +591,9 @@ class OverlapPatchEmbed(nn.Module):
         return x

-class BaseFeatureExtractionSAR(nn.Module):
-    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
-                 act_layer=nn.GELU,
-                 # norm_layer=nn.LayerNorm,
-                 drop=0., drop_path=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5):
-
-        super().__init__()
-
-        self.WTConv2d = WTConv2d(dim, dim)
-        self.norm1 = LayerNorm(dim, 'WithBias')
-        self.token_mixer = Pooling(kernel_size=pool_size)  # pooling as the token mixer
-        self.norm2 = LayerNorm(dim, 'WithBias')
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
-                               act_layer=act_layer, drop=drop)
-
-        # The following two techniques are useful to train deep PoolFormers.
-        self.drop_path = DropPath(drop_path) if drop_path > 0. \
-            else nn.Identity()
-        self.use_layer_scale = use_layer_scale
-
-        if use_layer_scale:
-            self.layer_scale_1 = nn.Parameter(
-                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-
-            self.layer_scale_2 = nn.Parameter(
-                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-
-    def forward(self, x):  # 1 64 128 128
-        if self.use_layer_scale:
-            # self.layer_scale_1: shape (64,)
-            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
-            normal = self.norm1(x)  # 1 64 128 128
-            token_mix = self.token_mixer(normal)  # 1 64 128 128
-
-            x = self.WTConv2d(x)
-
-            x = (x +
-                 self.drop_path(
-                     tmp1 * token_mix
-                 )
-                 # broadcast the (dim,) layer scale to (dim, 1, 1) for per-channel scaling
-                 )
-            x = x + self.drop_path(
-                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-                * self.poolmlp(self.norm2(x)))
-        else:
-            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # match CDDFuse
-            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
-        return x
-class DetailFeatureExtractionSAR(nn.Module):
-    def __init__(self, num_layers=3):
-        super(DetailFeatureExtractionSAR, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
-        self.net = nn.Sequential(*INNmodules)
-        self.enhancement_module = WTConv2d(32, 32)
-
-    def forward(self, x):  # 1 64 128 128
-        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # enhance both halves and keep the copies for residual connections
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # residual connections
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
-        for layer in self.net:
-            z1, z2 = layer(z1, z2)
-        return torch.cat((z1, z2), dim=1)
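
DetailNode realizes an INN-style affine coupling: each half of the channels is
updated from the other half (the context line
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2) above shows the
affine step), which keeps the mapping invertible regardless of what networks
theta_phi, theta_rho, and theta_eta are. Below is a minimal sketch of one
coupling step and its inverse; the 1x1 convs are stand-ins for the learned
blocks, the additive z2 update follows the standard coupling formulation
rather than being taken from the patch, and DetailNode's shffleconv channel
shuffle is omitted:

import torch
import torch.nn as nn

phi, rho, eta = (nn.Conv2d(32, 32, 1) for _ in range(3))

def couple(z1, z2):
    z2 = z2 + phi(z1)                          # additive update of z2 from z1
    z1 = z1 * torch.exp(rho(z2)) + eta(z2)     # affine update of z1 from z2
    return z1, z2

def uncouple(z1, z2):
    z1 = (z1 - eta(z2)) * torch.exp(-rho(z2))  # invert the affine step
    z2 = z2 - phi(z1)                          # invert the additive step
    return z1, z2

with torch.no_grad():
    z1, z2 = torch.randn(1, 32, 8, 8), torch.randn(1, 32, 8, 8)
    r1, r2 = uncouple(*couple(z1, z2))
    print(torch.allclose(r1, z1, atol=1e-4), torch.allclose(r2, z2, atol=1e-4))  # True True
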
print("\n"*2+"="*80) model_name="PFCFuse " print("The test result of "+dataset_name+' :') diff --git a/train.py b/train.py index 81d4d00..de853e8 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,8 @@ Import packages ------------------------------------------------------------------------------ ''' -from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \ + DetailFeatureFusion from utils.dataset import H5Dataset import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' @@ -31,11 +32,11 @@ Configure our network os.environ['CUDA_VISIBLE_DEVICES'] = '0' criteria_fusion = Fusionloss() -model_str = 'PFCFuse' +model_str = 'WhaiFuse' # . Set the hyper-parameters for training -num_epochs = 60 # total epoch -epoch_gap = 40 # epoches of Phase I +num_epochs = 120 # total epoch +epoch_gap = 80 # epoches of Phase I lr = 1e-4 weight_decay = 0 @@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) -BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) -DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) +BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device) +DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device) # optimizer, scheduler and loss function optimizer1 = torch.optim.Adam( @@ -109,7 +110,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean') HuberLoss = nn.HuberLoss() # data loader -trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"), +trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_128_stride_200.h5"), batch_size=batch_size, shuffle=True, num_workers=0) @@ -156,6 +157,7 @@ for epoch in range(num_epochs): cc_loss_B = cc(feature_V_B, feature_I_B) cc_loss_D = cc(feature_V_D, feature_I_D) + # HuberLoss 对比CDDFUSE的MSELoss mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat) mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat) @@ -169,6 +171,16 @@ for epoch in range(num_epochs): loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B) # print("loss_decomp", loss_decomp) + """ + -Huber 损失函数是一种鲁棒的回归损失函数, + 结合了均方误差(MSE)和绝对误差(MAE)的优点。 + 其设计目的是为了减少在数据中出现异常值(outliers)时对模型训练的影响。 + 在传统的 MSE 损失函数中,对大的误差给予较大的惩罚, + 而 MAE 则对所有误差给予同等的线性惩罚。 + 然而,MSE 在面对异常值时可能导致模型的不稳定, + 而 MAE 则可能导致梯度消失的问题。 + a-Huber 损失函数通过一个阈值 δ 来自适应地选择惩罚方式: + """ loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat) loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)