pfcfuse/train.py
zjut fa8106838e feat(model): refactor the model and add new functionality
- Reorganized the model structure and added a new feature-fusion module
- Added a depthwise-separable convolution block and a new detail-feature extraction module
- Updated the data-processing pipeline to use a new dataset path
- Adjusted the training parameters, increasing the number of epochs and the learning rate
- Optimized the loss function, replacing the MSE loss with the Huber loss
2024-11-12 10:37:56 +08:00


# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
    DetailFeatureFusion
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia
print(torch.__version__)
print(torch.cuda.is_available())
'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'WhaiFuse'
# Set the hyper-parameters for training
num_epochs = 120  # total number of epochs
epoch_gap = 80  # number of epochs in Phase I
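# Training is split into two phases (see the loop below): epochs < epoch_gap form
# Phase I, where the encoder/decoder are trained for reconstruction and feature
# decomposition; the remaining epochs form Phase II, where the base and detail
# fusion layers are trained jointly with the encoder/decoder.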
lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1. # alpha1
coeff_mse_loss_IF = 1.
coeff_rmi_loss_VF = 1.
coeff_rmi_loss_IF = 1.
coeff_cos_loss_VF = 1.
coeff_cos_loss_IF = 1.
coeff_decomp = 2. # alpha2 and alpha4
coeff_tv = 5.
clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5
# Print all hyper-parameters
print(f"Model: {model_str}")
print(f"Number of epochs: {num_epochs}")
print(f"Epoch gap: {epoch_gap}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Batch size: {batch_size}")
print(f"GPU number: {GPU_number}")
print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
print(f"Coefficient of Total Variation loss: {coeff_tv}")
print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}")
# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
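# Four trainable modules: a shared encoder, a decoder, and the base/detail
# feature-fusion layers; each gets its own Adam optimizer and StepLR scheduler below.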
# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)
MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss()
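# nn.HuberLoss with PyTorch defaults (delta=1.0, reduction='mean'); the standard
# Huber definition switches between a quadratic and a linear penalty:
#   l(x, y) = 0.5 * (x - y)**2                 if |x - y| <= delta
#           = delta * (|x - y| - 0.5 * delta)  otherwise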
# data loader
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_128_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)
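# The H5 file name suggests paired SAR/optical patches (128x128, stride 200); in the
# loop below data_VIS is the optical image and data_IR is passed to the encoder with
# is_sar=True, so the "IR" branch appears to carry the SAR modality here.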
loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")
'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''
step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()
start = time.time()
for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()
        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()
        if epoch < epoch_gap:  # Phase I
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR, is_sar=True)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)
            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            # Huber loss here, compared with the MSE loss used in CDDFuse
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)
            # print("mse_loss_V", mse_loss_V)
            # print("mse_loss_I", mse_loss_I)
            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))
            # print("Gradient_loss", Gradient_loss)
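            # Decomposition loss: minimizing cc_loss_D**2 / (1.01 + cc_loss_B) drives the
            # correlation between the two modalities' detail features toward zero while
            # rewarding a high correlation between their base features.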
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
            # print("loss_decomp", loss_decomp)
            """
            The Huber loss is a robust regression loss that combines the advantages of the
            mean squared error (MSE) and the mean absolute error (MAE). It is designed to
            reduce the influence of outliers on training. MSE penalizes large errors heavily,
            whereas MAE penalizes all errors with the same linear weight; as a result, MSE
            can make training unstable in the presence of outliers, while MAE can suffer from
            vanishing gradients. The Huber loss adaptively switches between the two penalties
            using a threshold delta (see the note where HuberLoss is instantiated above).
            """
            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)
            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                   mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
                   coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v
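            # Phase I total: SSIM + Huber reconstruction terms for both modalities, the
            # decomposition term, a spatial-gradient term, and the two relative-difference
            # terms, each weighted by the coefficients set above.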
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR, is_sar=True)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
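            # Phase II: base and detail features from the two modalities are summed, passed
            # through their respective fusion layers, and decoded into the fused image;
            # all four modules are updated in this phase.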
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse)
            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)
            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()
        # Determine approximate time left
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        epoch_time = time.time() - prev_time
        prev_time = time.time()
        if i % 1 == 0:
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
                % (
                    epoch,
                    num_epochs,
                    i,
                    len(loader['train']),
                    loss.item(),
                    time_left,
                )
            )
    # adjust the learning rate
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        scheduler3.step()
        scheduler4.step()
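    # Clamp each learning rate at a floor of 1e-6 so the StepLR decay cannot push it lower.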
    if optimizer1.param_groups[0]['lr'] <= 1e-6:
        optimizer1.param_groups[0]['lr'] = 1e-6
    if optimizer2.param_groups[0]['lr'] <= 1e-6:
        optimizer2.param_groups[0]['lr'] = 1e-6
    if optimizer3.param_groups[0]['lr'] <= 1e-6:
        optimizer3.param_groups[0]['lr'] = 1e-6
    if optimizer4.param_groups[0]['lr'] <= 1e-6:
        optimizer4.param_groups[0]['lr'] = 1e-6

if True:
    checkpoint = {
        'DIDF_Encoder': DIDF_Encoder.state_dict(),
        'DIDF_Decoder': DIDF_Decoder.state_dict(),
        'BaseFuseLayer': BaseFuseLayer.state_dict(),
        'DetailFuseLayer': DetailFuseLayer.state_dict(),
    }
    torch.save(checkpoint, os.path.join("models/whaiFusion"+timestamp+'.pth'))
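# A minimal sketch of how the saved weights might be restored for inference
# (assumes the same DataParallel-wrapped module definitions as above; the file name
# is the one written by torch.save just above):
#
#   ckpt = torch.load("models/whaiFusion" + timestamp + ".pth", map_location=device)
#   DIDF_Encoder.load_state_dict(ckpt['DIDF_Encoder'])
#   DIDF_Decoder.load_state_dict(ckpt['DIDF_Decoder'])
#   BaseFuseLayer.load_state_dict(ckpt['BaseFuseLayer'])
#   DetailFuseLayer.load_state_dict(ckpt['DetailFuseLayer'])
#   DIDF_Encoder.eval(); DIDF_Decoder.eval(); BaseFuseLayer.eval(); DetailFuseLayer.eval()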