# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia

print(torch.__version__)
print(torch.cuda.is_available())

'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

criteria_fusion = Fusionloss()
model_str = 'PFCFuse'

# Set the hyper-parameters for training
num_epochs = 60   # total number of training epochs
epoch_gap = 40    # epochs of Phase I

lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']

# Coefficients of the loss function
coeff_mse_loss_VF = 1.   # alpha1
coeff_mse_loss_IF = 1.

coeff_rmi_loss_VF = 1.
coeff_rmi_loss_IF = 1.

coeff_cos_loss_VF = 1.
coeff_cos_loss_IF = 1.

coeff_decomp = 2.        # alpha2 and alpha4
coeff_tv = 5.

clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5

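# The coefficients above weight the Phase I objective assembled in the training loop:
#     loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * mse_loss_I
#            + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss
#            + coeff_rmi_loss_VF * loss_rmi_v + coeff_rmi_loss_IF * loss_rmi_i
# Note: coeff_cos_loss_VF / coeff_cos_loss_IF are defined and printed but are not
# referenced anywhere in the training loop below.
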
# Print all training hyper-parameters
print(f"Model: {model_str}")
print(f"Number of epochs: {num_epochs}")
print(f"Epoch gap: {epoch_gap}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Batch size: {batch_size}")
print(f"GPU number: {GPU_number}")

print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
print(f"Coefficient of Total Variation loss: {coeff_tv}")

print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}")

# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

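# The encoder produces a base-feature branch and a detail-feature branch (the *_B and
# *_D tensors used below); BaseFuseLayer and DetailFuseLayer fuse the corresponding
# branches of the two modalities during Phase II before decoding the fused image.
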
# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
# Note: newer kornia releases rename this class to kornia.losses.SSIMLoss; keep
# kornia.losses.SSIM for older versions, or switch if an AttributeError is raised.
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss()

# data loader
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")

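# H5Dataset is expected to yield paired patches that unpack as (data_VIS, data_IR) in the
# training loop below; judging by the file name, the pairs here come from an optical/SAR
# H5 file of 256x256 patches extracted with stride 200.
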
'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()
start = time.time()

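# Two-phase schedule:
#   Phase I  (epoch < epoch_gap): only the encoder/decoder are updated, trained to
#            reconstruct each input from its base + detail features (SSIM/Huber,
#            gradient and relative-difference terms) plus the decomposition loss.
#   Phase II (epoch >= epoch_gap): the base/detail fusion layers join in, and the
#            fusion loss on the fused image drives all four modules.
# The decomposition loss, (cc_loss_D) ** 2 / (1.01 + cc_loss_B), pushes the correlation
# of the detail features toward zero while rewarding correlated base features.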
for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()

        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()

        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        if epoch < epoch_gap:  # Phase I
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR, is_sar=True)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)

            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)
            # print("mse_loss_V", mse_loss_V)
            # print("mse_loss_I", mse_loss_I)

            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))
            # print("Gradient_loss", Gradient_loss)

            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
            # print("loss_decomp", loss_decomp)

            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)

            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                   mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
                   coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v

            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR, is_sar=True)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            # (computed for reference; not included in the Phase II objective)
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()

        # Determine approximate time left
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        epoch_time = time.time() - prev_time
        prev_time = time.time()

        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
            % (
                epoch,
                num_epochs,
                i,
                len(loader['train']),
                loss.item(),
                time_left,
            )
        )

    # adjust the learning rate
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        scheduler3.step()
        scheduler4.step()

    # keep every learning rate from decaying below 1e-6
    if optimizer1.param_groups[0]['lr'] <= 1e-6:
        optimizer1.param_groups[0]['lr'] = 1e-6
    if optimizer2.param_groups[0]['lr'] <= 1e-6:
        optimizer2.param_groups[0]['lr'] = 1e-6
    if optimizer3.param_groups[0]['lr'] <= 1e-6:
        optimizer3.param_groups[0]['lr'] = 1e-6
    if optimizer4.param_groups[0]['lr'] <= 1e-6:
        optimizer4.param_groups[0]['lr'] = 1e-6

    # save a checkpoint after every epoch (the file is overwritten for a given run timestamp)
    checkpoint = {
        'DIDF_Encoder': DIDF_Encoder.state_dict(),
        'DIDF_Decoder': DIDF_Decoder.state_dict(),
        'BaseFuseLayer': BaseFuseLayer.state_dict(),
        'DetailFuseLayer': DetailFuseLayer.state_dict(),
    }
    os.makedirs("models", exist_ok=True)
    torch.save(checkpoint, os.path.join("models", "whaiFusion" + timestamp + '.pth'))
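
# A minimal loading sketch for later evaluation (assumes the same nn.DataParallel-wrapped
# modules defined above; the path is illustrative):
#
#   ckpt = torch.load("models/whaiFusion" + timestamp + ".pth", map_location=device)
#   DIDF_Encoder.load_state_dict(ckpt['DIDF_Encoder'])
#   DIDF_Decoder.load_state_dict(ckpt['DIDF_Decoder'])
#   BaseFuseLayer.load_state_dict(ckpt['BaseFuseLayer'])
#   DetailFuseLayer.load_state_dict(ckpt['DetailFuseLayer'])
#   DIDF_Encoder.eval(); DIDF_Decoder.eval(); BaseFuseLayer.eval(); DetailFuseLayer.eval()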