# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''
import os

# These environment variables must be set before torch (or any project module
# that imports torch) is loaded:
#  * CUDA_VISIBLE_DEVICES is read when the CUDA runtime initialises. The
#    torch.cuda.is_available() call below queries the device count and can
#    already trigger that initialisation, so assigning the variable after it
#    (as this script used to) may silently have no effect.
#  * KMP_DUPLICATE_LIB_OK lets the process continue when more than one OpenMP
#    runtime gets loaded.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import sys
import time
import datetime

import kornia
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
from utils.loss import Fusionloss, cc, relative_diff_loss

print(torch.__version__)
print(torch.cuda.is_available())

'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''

criteria_fusion = Fusionloss()
model_str = 'PFCFuse'

# Set the hyper-parameters for training
num_epochs = 60  # total number of epochs
epoch_gap = 40  # number of Phase I epochs; epochs >= epoch_gap run Phase II
lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']

# Coefficients of the loss function
coeff_mse_loss_VF = 1.  # alpha1
coeff_mse_loss_IF = 1.
coeff_rmi_loss_VF = 1.
coeff_rmi_loss_IF = 1.
# NOTE(review): the two cosine-loss coefficients are only printed; no loss
# term in this script uses them.
coeff_cos_loss_VF = 1.
coeff_cos_loss_IF = 1.
coeff_decomp = 2.  # alpha2 and alpha4
coeff_tv = 5.  # weight of the Phase I gradient (SpatialGradient) L1 loss
clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5

# Print all parameters so that every run's log records its configuration.
print(f"Model: {model_str}")
print(f"Number of epochs: {num_epochs}")
print(f"Epoch gap: {epoch_gap}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Batch size: {batch_size}")
print(f"GPU number: {GPU_number}")
print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
print(f"Coefficient of Total Variation loss: {coeff_tv}")
print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}")

# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

# Optimizer, scheduler and loss function: one Adam + StepLR pair per sub-network.
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss()
# Stateless module: one instance is reused instead of building two per batch.
spatial_gradient = kornia.filters.SpatialGradient()

# Data loader
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")

'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

torch.backends.cudnn.benchmark = True
prev_time = time.time()
start = time.time()

models = (DIDF_Encoder, DIDF_Decoder, BaseFuseLayer, DetailFuseLayer)
optimizers = (optimizer1, optimizer2, optimizer3, optimizer4)
num_batches = len(loader['train'])

# Create the output directory up front: previously a missing "models/" made the
# final torch.save fail only after the whole run had finished.
checkpoint_dir = "models"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "whaiFusion" + timestamp + '.pth')

for epoch in range(num_epochs):
    ''' train '''
    # Nothing in this script switches to eval mode, so once per epoch is enough.
    for model in models:
        model.train()

    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        # Use the selected device instead of .cuda(), which crashed on CPU-only hosts.
        data_VIS, data_IR = data_VIS.to(device), data_IR.to(device)

        # Each optimizer owns all parameters of its sub-network, so this clears
        # every gradient (calling module.zero_grad() as well was redundant).
        for optimizer in optimizers:
            optimizer.zero_grad()

        if epoch < epoch_gap:  # Phase I: train encoder/decoder as auto-encoders
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR, sar_img=True)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            # Reconstruction terms; despite the "mse" names they are 5*SSIM + Huber.
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)
            Gradient_loss = L1Loss(spatial_gradient(data_VIS), spatial_gradient(data_VIS_hat))
            # Decomposition loss: small when detail features agree little and base
            # features agree much (assuming cc is a correlation measure — confirm
            # against utils.loss).
            loss_decomp = cc_loss_D ** 2 / (1.01 + cc_loss_B)
            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)

            loss = (coeff_mse_loss_VF * mse_loss_V
                    + coeff_mse_loss_IF * mse_loss_I
                    + coeff_decomp * loss_decomp
                    + coeff_tv * Gradient_loss
                    + coeff_rmi_loss_IF * loss_rmi_i
                    + coeff_rmi_loss_VF * loss_rmi_v)

            loss.backward()
            # Only encoder and decoder take part in Phase I.
            for model in (DIDF_Encoder, DIDF_Decoder):
                nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II: train the fusion layers together with encoder/decoder
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            # Phase I trains the encoder on this SAR/IR input with sar_img=True;
            # use the same path here so Phase II fine-tunes what was pre-trained.
            # NOTE(review): inference code should pass the same flag.
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR, sar_img=True)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, _ = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = cc_loss_D ** 2 / (1.01 + cc_loss_B)
            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            for model in models:
                nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            for optimizer in optimizers:
                optimizer.step()

        # Determine approximate time left (the current batch is already done).
        batches_done = epoch * num_batches + i + 1
        batches_left = num_epochs * num_batches - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
            % (
                epoch,
                num_epochs,
                i,
                num_batches,
                loss.item(),
                time_left,
            )
        )
        # "\r" lines have no newline, so flush explicitly or progress may not show.
        sys.stdout.flush()

    # Adjust the learning rate; the fusion layers only start decaying in Phase II.
    scheduler1.step()
    scheduler2.step()
    if epoch >= epoch_gap:
        scheduler3.step()
        scheduler4.step()

    # Never let any learning rate drop below 1e-6.
    for optimizer in optimizers:
        for group in optimizer.param_groups:
            group['lr'] = max(group['lr'], 1e-6)

    # Save (overwriting) after every epoch so a crash late in the run does not
    # lose all training; the file left after the last epoch is the final model.
    checkpoint = {
        'DIDF_Encoder': DIDF_Encoder.state_dict(),
        'DIDF_Decoder': DIDF_Decoder.state_dict(),
        'BaseFuseLayer': BaseFuseLayer.state_dict(),
        'DetailFuseLayer': DetailFuseLayer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)

sys.stdout.write("\n")
print(f"Training finished in {datetime.timedelta(seconds=time.time() - start)}; "
      f"checkpoint saved to {checkpoint_path}")