# Commit fa8106838e
# - Reorganized the model structure and added a new feature-fusion module
# - Added depthwise-separable convolution blocks and a new detail-feature extraction module
# - Updated the data pipeline to use a new dataset path
# - Adjusted the training parameters: more epochs and an updated learning rate
# - Replaced the MSE reconstruction loss with the Huber loss
# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''

# Project-local model components: Restormer-based encoder/decoder plus the
# base/detail feature extraction and fusion modules.
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
    DetailFeatureFusion
from utils.dataset import H5Dataset
import os
# Allow duplicate OpenMP runtimes to coexist; works around the common
# "libiomp5 already initialized" abort seen with some PyTorch/MKL installs.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Project-local losses: overall fusion loss, correlation coefficient (cc),
# and a relative-difference loss.
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia  # differentiable CV ops (SSIM loss, spatial gradients)
||
# Select which physical GPU(s) this process may see. The mask must be set
# BEFORE anything initialises the CUDA runtime — torch.cuda.is_available()
# below does exactly that — otherwise the assignment is silently ignored.
# (In the original ordering this line came after the availability check.)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

print(torch.__version__)
print(torch.cuda.is_available())

'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''

criteria_fusion = Fusionloss()
model_str = 'WhaiFuse'

# Hyper-parameters for training.
num_epochs = 120  # total epoch
epoch_gap = 80  # epoches of Phase I (reconstruction pre-training)

lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']

# Coefficients of the loss function
coeff_mse_loss_VF = 1.  # alpha1: visible-branch reconstruction weight
coeff_mse_loss_IF = 1.  # infrared/SAR-branch reconstruction weight

coeff_rmi_loss_VF = 1.  # relative-difference loss weights
coeff_rmi_loss_IF = 1.

coeff_cos_loss_VF = 1.  # NOTE(review): cosine-loss weights are declared but
coeff_cos_loss_IF = 1.  # never used below — confirm whether they are stale.
coeff_decomp = 2.  # alpha2 and alpha4: feature-decomposition weight
coeff_tv = 5.  # gradient (edge-preservation) weight

clip_grad_norm_value = 0.01
optim_step = 20  # StepLR period (epochs)
optim_gamma = 0.5  # StepLR decay factor

# Print all parameters
print(f"Model: {model_str}")
print(f"Number of epochs: {num_epochs}")
print(f"Epoch gap: {epoch_gap}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Batch size: {batch_size}")
print(f"GPU number: {GPU_number}")

print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
print(f"Coefficient of Total Variation loss: {coeff_tv}")

print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}")
||
# Model
# All four sub-networks are wrapped in DataParallel and placed on the GPU
# when available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)

# optimizer, scheduler and loss function
# One Adam optimizer per sub-network so the two training phases can step
# different subsets (Phase I: encoder/decoder only; Phase II: all four).
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

# Step decay: multiply each LR by optim_gamma every optim_step epochs.
scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

# Loss primitives. MSELoss is kept although reconstruction below uses Huber.
MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
# SSIM with an 11x11 window, mean-reduced.
# NOTE(review): newer kornia releases renamed this to kornia.losses.SSIMLoss;
# verify the installed kornia still exposes losses.SSIM with this signature.
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
# Huber: quadratic for small residuals, linear for large ones, so outliers
# influence the gradient less than with MSE.
HuberLoss = nn.HuberLoss()
||
# data loader
# H5Dataset yields (visible, SAR) patch pairs from a pre-cropped HDF5 file
# (128x128 patches with stride 200, per the filename — confirm against the
# dataset-builder script). num_workers=0 keeps loading in the main process.
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_128_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
# Timestamp (month-day-hour-minute) used to name the saved checkpoint.
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")
||
'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

step = 0  # NOTE(review): never incremented or read below — likely stale.
torch.backends.cudnn.benchmark = True  # autotune conv kernels (input size is fixed)
prev_time = time.time()

start = time.time()  # NOTE(review): unused below.
for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        # NOTE(review): .cuda() assumes a GPU is present even though `device`
        # above may be 'cpu' — this will raise without CUDA.
        data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()

        # zero_grad on the modules and the optimizers (redundant but harmless).
        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        if epoch < epoch_gap: #Phase I: reconstruction pre-training
            # Encoder returns (base features, detail features, third output
            # unused here).
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR,is_sar = True)
            # Reconstruct each modality from its own features.
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            # Cross-modal correlation of base and of detail features.
            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)

            # Huber loss here, versus the MSELoss used in CDDFuse.
            # Reconstruction loss = 5 * SSIM + Huber, per modality.
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)

            # print("mse_loss_V", mse_loss_V)
            # print("mse_loss_I", mse_loss_I)

            # Edge-preservation term: L1 between spatial gradients of the
            # visible input and its reconstruction (visible branch only).
            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))
            # print("Gradient_loss", Gradient_loss)

            # Decomposition loss: penalise correlated detail features
            # (numerator) relative to base-feature correlation (denominator);
            # the 1.01 offset keeps the denominator positive.
            loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
            # print("loss_decomp", loss_decomp)

            """
            The Huber loss is a robust regression loss that combines the
            strengths of mean squared error (MSE) and mean absolute error
            (MAE). It is designed to reduce the influence of outliers in the
            data on training. MSE penalises large errors heavily while MAE
            penalises all errors linearly; MSE can destabilise the model in
            the presence of outliers, while MAE can suffer from vanishing
            gradients. The Huber loss switches adaptively between the two
            penalty regimes via a threshold delta.
            """

            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)

            # Phase I objective: weighted sum of reconstruction, decomposition,
            # gradient, and relative-difference terms.
            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                   mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
                   coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v

            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            # Phase I updates only the encoder and decoder.
            optimizer1.step()
            optimizer2.step()
        else: #Phase II: train the fusion layers end-to-end
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR,is_sar = True)
            # Merge the two modalities by element-wise sum, then refine with
            # the dedicated base/detail fusion layers.
            feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            # NOTE(review): mse_loss_V / mse_loss_I are computed here but are
            # NOT included in the Phase II loss below — confirm intentional.
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            # criteria_fusion returns (total fusion loss, plus two auxiliary
            # terms that are discarded here).
            fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            # Phase II updates all four sub-networks.
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()

        # Determine approximate time left (per-batch duration extrapolated
        # over the remaining batches).
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        epoch_time = time.time() - prev_time  # NOTE(review): unused.
        prev_time = time.time()
        if i % 1 == 0:  # always true: progress is logged every batch
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
                % (
                    epoch,
                    num_epochs,
                    i,
                    len(loader['train']),
                    loss.item(),
                    time_left,
                )
            )

    # adjust the learning rate (once per epoch)
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        # Fusion-layer schedulers only advance during Phase II.
        scheduler3.step()
        scheduler4.step()

    # Floor each learning rate at 1e-6.
    if optimizer1.param_groups[0]['lr'] <= 1e-6:
        optimizer1.param_groups[0]['lr'] = 1e-6
    if optimizer2.param_groups[0]['lr'] <= 1e-6:
        optimizer2.param_groups[0]['lr'] = 1e-6
    if optimizer3.param_groups[0]['lr'] <= 1e-6:
        optimizer3.param_groups[0]['lr'] = 1e-6
    if optimizer4.param_groups[0]['lr'] <= 1e-6:
        optimizer4.param_groups[0]['lr'] = 1e-6
||
# Save the trained weights once training completes.
# Fixes vs. original: removed the always-true `if True:` guard; create the
# target directory first (torch.save raises FileNotFoundError otherwise);
# pass the path components to os.path.join separately instead of joining a
# single pre-concatenated string (a no-op call).
os.makedirs("models", exist_ok=True)
checkpoint = {
    'DIDF_Encoder': DIDF_Encoder.state_dict(),
    'DIDF_Decoder': DIDF_Decoder.state_dict(),
    'BaseFuseLayer': BaseFuseLayer.state_dict(),
    'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
torch.save(checkpoint, os.path.join("models", "whaiFusion" + timestamp + ".pth"))