main
This commit is contained in:
parent
0c2350bc2c
commit
3fe9a38165
@ -85,11 +85,11 @@ If you want to infer with our CDDFuse and obtain the fusion results in our paper
|
|||||||
```
|
```
|
||||||
python test_IVF.py
|
python test_IVF.py
|
||||||
```
|
```
|
||||||
for IVF and
|
for Infrared-Visible Fusion and
|
||||||
```
|
```
|
||||||
python test_MIF.py
|
python test_MIF.py
|
||||||
```
|
```
|
||||||
for MIF.
|
for Medical Image Fusion.
|
||||||
|
|
||||||
The testing results will be printed in the terminal.
|
The testing results will be printed in the terminal.
|
||||||
|
|
||||||
|
93
dataprocessing.py
Normal file
93
dataprocessing.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from skimage.io import imread
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_file(file_name):
|
||||||
|
imagelist = []
|
||||||
|
for parent, dirnames, filenames in os.walk(file_name):
|
||||||
|
for filename in filenames:
|
||||||
|
if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
|
||||||
|
imagelist.append(os.path.join(parent, filename))
|
||||||
|
return imagelist
|
||||||
|
|
||||||
|
def rgb2y(img):
|
||||||
|
y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
|
||||||
|
return y
|
||||||
|
|
||||||
|
def Im2Patch(img, win, stride=1):
|
||||||
|
k = 0
|
||||||
|
endc = img.shape[0]
|
||||||
|
endw = img.shape[1]
|
||||||
|
endh = img.shape[2]
|
||||||
|
patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
|
||||||
|
TotalPatNum = patch.shape[1] * patch.shape[2]
|
||||||
|
Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
|
||||||
|
for i in range(win):
|
||||||
|
for j in range(win):
|
||||||
|
patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
|
||||||
|
Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
|
||||||
|
k = k + 1
|
||||||
|
return Y.reshape([endc, win, win, TotalPatNum])
|
||||||
|
|
||||||
|
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
|
||||||
|
upper_percentile=90):
|
||||||
|
"""Determine if an image is low contrast."""
|
||||||
|
limits = np.percentile(image, [lower_percentile, upper_percentile])
|
||||||
|
ratio = (limits[1] - limits[0]) / limits[1]
|
||||||
|
return ratio < fraction_threshold
|
||||||
|
|
||||||
|
data_name="MSRS_train"
|
||||||
|
img_size=128 #patch size
|
||||||
|
stride=200 #patch stride
|
||||||
|
|
||||||
|
IR_files = sorted(get_img_file(r"MSRS_train/ir"))
|
||||||
|
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
|
||||||
|
|
||||||
|
assert len(IR_files) == len(VIS_files)
|
||||||
|
h5f = h5py.File(os.path.join('.\\data',
|
||||||
|
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
|
||||||
|
'w')
|
||||||
|
h5_ir = h5f.create_group('ir_patchs')
|
||||||
|
h5_vis = h5f.create_group('vis_patchs')
|
||||||
|
train_num=0
|
||||||
|
for i in tqdm(range(len(IR_files))):
|
||||||
|
I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
|
||||||
|
I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
|
||||||
|
I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
|
||||||
|
|
||||||
|
# crop
|
||||||
|
I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
|
||||||
|
I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12)
|
||||||
|
|
||||||
|
for ii in range(I_IR_Patch_Group.shape[-1]):
|
||||||
|
bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
|
||||||
|
bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
|
||||||
|
# Determine if the contrast is low
|
||||||
|
if not (bad_IR or bad_VIS):
|
||||||
|
avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR
|
||||||
|
avl_VIS= I_VIS_Patch_Group[0,:,:,ii]
|
||||||
|
avl_IR=avl_IR[None,...]
|
||||||
|
avl_VIS=avl_VIS[None,...]
|
||||||
|
|
||||||
|
h5_ir.create_dataset(str(train_num), data=avl_IR,
|
||||||
|
dtype=avl_IR.dtype, shape=avl_IR.shape)
|
||||||
|
h5_vis.create_dataset(str(train_num), data=avl_VIS,
|
||||||
|
dtype=avl_VIS.dtype, shape=avl_VIS.shape)
|
||||||
|
train_num += 1
|
||||||
|
|
||||||
|
h5f.close()
|
||||||
|
|
||||||
|
with h5py.File(os.path.join('data',
|
||||||
|
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
print(f[key], key, f[key].name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
219
train.py
Normal file
219
train.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
Import packages
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
'''
|
||||||
|
|
||||||
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||||
|
from utils.dataset import H5Dataset
|
||||||
|
import os
|
||||||
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from utils.loss import Fusionloss, cc
|
||||||
|
import kornia
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
Configure our network
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
criteria_fusion = Fusionloss()
|
||||||
|
model_str = 'CDDFuse'
|
||||||
|
|
||||||
|
# . Set the hyper-parameters for training
|
||||||
|
num_epochs = 120 # total epoch
|
||||||
|
epoch_gap = 40 # epoches of Phase I
|
||||||
|
|
||||||
|
lr = 1e-4
|
||||||
|
weight_decay = 0
|
||||||
|
batch_size = 8
|
||||||
|
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
|
||||||
|
# Coefficients of the loss function
|
||||||
|
coeff_mse_loss_VF = 1. # alpha1
|
||||||
|
coeff_mse_loss_IF = 1.
|
||||||
|
coeff_decomp = 2. # alpha2 and alpha4
|
||||||
|
coeff_tv = 5.
|
||||||
|
|
||||||
|
clip_grad_norm_value = 0.01
|
||||||
|
optim_step = 20
|
||||||
|
optim_gamma = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# Model
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
|
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
|
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||||
|
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||||
|
|
||||||
|
# optimizer, scheduler and loss function
|
||||||
|
optimizer1 = torch.optim.Adam(
|
||||||
|
DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
optimizer2 = torch.optim.Adam(
|
||||||
|
DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
optimizer3 = torch.optim.Adam(
|
||||||
|
BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
optimizer4 = torch.optim.Adam(
|
||||||
|
DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
|
||||||
|
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
|
||||||
|
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
|
||||||
|
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)
|
||||||
|
|
||||||
|
MSELoss = nn.MSELoss()
|
||||||
|
L1Loss = nn.L1Loss()
|
||||||
|
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
|
||||||
|
|
||||||
|
|
||||||
|
# data loader
|
||||||
|
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0)
|
||||||
|
|
||||||
|
loader = {'train': trainloader, }
|
||||||
|
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")
|
||||||
|
|
||||||
|
'''
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
Train
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
'''
|
||||||
|
|
||||||
|
step = 0
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
prev_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
''' train '''
|
||||||
|
for i, (data_VIS, data_IR) in enumerate(loader['train']):
|
||||||
|
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||||
|
DIDF_Encoder.train()
|
||||||
|
DIDF_Decoder.train()
|
||||||
|
BaseFuseLayer.train()
|
||||||
|
DetailFuseLayer.train()
|
||||||
|
|
||||||
|
DIDF_Encoder.zero_grad()
|
||||||
|
DIDF_Decoder.zero_grad()
|
||||||
|
BaseFuseLayer.zero_grad()
|
||||||
|
DetailFuseLayer.zero_grad()
|
||||||
|
|
||||||
|
optimizer1.zero_grad()
|
||||||
|
optimizer2.zero_grad()
|
||||||
|
optimizer3.zero_grad()
|
||||||
|
optimizer4.zero_grad()
|
||||||
|
|
||||||
|
if epoch < epoch_gap: #Phase I
|
||||||
|
feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
|
||||||
|
feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
|
||||||
|
data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
|
||||||
|
data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)
|
||||||
|
|
||||||
|
cc_loss_B = cc(feature_V_B, feature_I_B)
|
||||||
|
cc_loss_D = cc(feature_V_D, feature_I_D)
|
||||||
|
mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + MSELoss(data_VIS, data_VIS_hat)
|
||||||
|
mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + MSELoss(data_IR, data_IR_hat)
|
||||||
|
|
||||||
|
Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
|
||||||
|
kornia.filters.SpatialGradient()(data_VIS_hat))
|
||||||
|
|
||||||
|
loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
|
||||||
|
|
||||||
|
loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
|
||||||
|
mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||||
|
optimizer1.step()
|
||||||
|
optimizer2.step()
|
||||||
|
else: #Phase II
|
||||||
|
feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
|
||||||
|
feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
|
||||||
|
feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
|
||||||
|
feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
|
||||||
|
data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
|
||||||
|
|
||||||
|
|
||||||
|
mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
|
||||||
|
mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)
|
||||||
|
|
||||||
|
cc_loss_B = cc(feature_V_B, feature_I_B)
|
||||||
|
cc_loss_D = cc(feature_V_D, feature_I_D)
|
||||||
|
loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
|
||||||
|
fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse)
|
||||||
|
|
||||||
|
loss = fusionloss + coeff_decomp * loss_decomp
|
||||||
|
loss.backward()
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||||
|
optimizer1.step()
|
||||||
|
optimizer2.step()
|
||||||
|
optimizer3.step()
|
||||||
|
optimizer4.step()
|
||||||
|
|
||||||
|
# Determine approximate time left
|
||||||
|
batches_done = epoch * len(loader['train']) + i
|
||||||
|
batches_left = num_epochs * len(loader['train']) - batches_done
|
||||||
|
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
|
||||||
|
prev_time = time.time()
|
||||||
|
sys.stdout.write(
|
||||||
|
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
||||||
|
% (
|
||||||
|
epoch,
|
||||||
|
num_epochs,
|
||||||
|
i,
|
||||||
|
len(loader['train']),
|
||||||
|
loss.item(),
|
||||||
|
time_left,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# adjust the learning rate
|
||||||
|
|
||||||
|
scheduler1.step()
|
||||||
|
scheduler2.step()
|
||||||
|
if not epoch < epoch_gap:
|
||||||
|
scheduler3.step()
|
||||||
|
scheduler4.step()
|
||||||
|
|
||||||
|
if optimizer1.param_groups[0]['lr'] <= 1e-6:
|
||||||
|
optimizer1.param_groups[0]['lr'] = 1e-6
|
||||||
|
if optimizer2.param_groups[0]['lr'] <= 1e-6:
|
||||||
|
optimizer2.param_groups[0]['lr'] = 1e-6
|
||||||
|
if optimizer3.param_groups[0]['lr'] <= 1e-6:
|
||||||
|
optimizer3.param_groups[0]['lr'] = 1e-6
|
||||||
|
if optimizer4.param_groups[0]['lr'] <= 1e-6:
|
||||||
|
optimizer4.param_groups[0]['lr'] = 1e-6
|
||||||
|
|
||||||
|
if True:
|
||||||
|
checkpoint = {
|
||||||
|
'DIDF_Encoder': DIDF_Encoder.state_dict(),
|
||||||
|
'DIDF_Decoder': DIDF_Decoder.state_dict(),
|
||||||
|
'BaseFuseLayer': BaseFuseLayer.state_dict(),
|
||||||
|
'DetailFuseLayer': DetailFuseLayer.state_dict(),
|
||||||
|
}
|
||||||
|
torch.save(checkpoint, os.path.join("models/CDDFuse_"+timestamp+'.pth'))
|
||||||
|
|
||||||
|
|
Binary file not shown.
BIN
utils/__pycache__/dataset.cpython-38.pyc
Normal file
BIN
utils/__pycache__/dataset.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
utils/__pycache__/loss.cpython-38.pyc
Normal file
BIN
utils/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
22
utils/dataset.py
Normal file
22
utils/dataset.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import torch.utils.data as Data
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class H5Dataset(Data.Dataset):
|
||||||
|
def __init__(self, h5file_path):
|
||||||
|
self.h5file_path = h5file_path
|
||||||
|
h5f = h5py.File(h5file_path, 'r')
|
||||||
|
self.keys = list(h5f['ir_patchs'].keys())
|
||||||
|
h5f.close()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.keys)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
h5f = h5py.File(self.h5file_path, 'r')
|
||||||
|
key = self.keys[index]
|
||||||
|
IR = np.array(h5f['ir_patchs'][key])
|
||||||
|
VIS = np.array(h5f['vis_patchs'][key])
|
||||||
|
h5f.close()
|
||||||
|
return torch.Tensor(VIS), torch.Tensor(IR)
|
@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
class Fusionloss(nn.Module):
|
class Fusionloss(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -38,3 +38,17 @@ class Sobelxy(nn.Module):
|
|||||||
sobelx=F.conv2d(x, self.weightx, padding=1)
|
sobelx=F.conv2d(x, self.weightx, padding=1)
|
||||||
sobely=F.conv2d(x, self.weighty, padding=1)
|
sobely=F.conv2d(x, self.weighty, padding=1)
|
||||||
return torch.abs(sobelx)+torch.abs(sobely)
|
return torch.abs(sobelx)+torch.abs(sobely)
|
||||||
|
|
||||||
|
|
||||||
|
def cc(img1, img2):
|
||||||
|
eps = torch.finfo(torch.float32).eps
|
||||||
|
"""Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.]."""
|
||||||
|
N, C, _, _ = img1.shape
|
||||||
|
img1 = img1.reshape(N, C, -1)
|
||||||
|
img2 = img2.reshape(N, C, -1)
|
||||||
|
img1 = img1 - img1.mean(dim=-1, keepdim=True)
|
||||||
|
img2 = img2 - img2.mean(dim=-1, keepdim=True)
|
||||||
|
cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
|
||||||
|
2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
|
||||||
|
cc = torch.clamp(cc, -1., 1.)
|
||||||
|
return cc.mean()
|
Loading…
Reference in New Issue
Block a user