This commit is contained in: main
parent 0c2350bc2c
commit 3fe9a38165
@@ -85,11 +85,11 @@ If you want to infer with our CDDFuse and obtain the fusion results in our paper
 ```
 python test_IVF.py
 ```
-for IVF and
+for Infrared-Visible Fusion and
 ```
 python test_MIF.py
 ```
-for MIF.
+for Medical Image Fusion.
 
 The testing results will be printed in the terminal.
 
93 dataprocessing.py Normal file
@@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread


def get_img_file(file_name):
    # Recursively collect all image (and .npy) file paths under a directory.
    imagelist = []
    for parent, dirnames, filenames in os.walk(file_name):
        for filename in filenames:
            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
                imagelist.append(os.path.join(parent, filename))
    return imagelist


def rgb2y(img):
    # ITU-R BT.601 luma transform: expects [3, H, W], returns the [1, H, W] Y channel.
    y = img[0:1, :, :] * 0.299 + img[1:2, :, :] * 0.587 + img[2:3, :, :] * 0.114
    return y


def Im2Patch(img, win, stride=1):
    # Slice a [C, H, W] image into win x win patches sampled every `stride` pixels;
    # returns an array of shape [C, win, win, TotalPatNum].
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+1:stride, 0:endh-win+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])
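# Worked example (illustrative numbers, not the actual MSRS image size): for a
# [1, 480, 480] image with win=128 and stride=200, valid patch origins along
# each axis are {0, 200} (since 480-128+1 = 353), so Im2Patch returns a
# [1, 128, 128, 4] array holding 2 x 2 = 4 patches.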

def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
                    upper_percentile=90):
    """Determine if an image is low contrast."""
    limits = np.percentile(image, [lower_percentile, upper_percentile])
    ratio = (limits[1] - limits[0]) / limits[1]
    return ratio < fraction_threshold


data_name = "MSRS_train"
img_size = 128  # patch size
stride = 200    # patch stride

IR_files = sorted(get_img_file(r"MSRS_train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))

assert len(IR_files) == len(VIS_files)
os.makedirs('data', exist_ok=True)
h5f = h5py.File(os.path.join('data',
                             data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'),
                'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num = 0
for i in tqdm(range(len(IR_files))):
    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1) / 255.  # [3, H, W] uint8 -> float32
    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :] / 255.  # [1, H, W] float32

    # crop into patches
    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # [1, img_size, img_size, N]

    for ii in range(I_IR_Patch_Group.shape[-1]):
        # keep only patch pairs where neither modality is low contrast
        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
        if not (bad_IR or bad_VIS):
            avl_IR = I_IR_Patch_Group[0, :, :, ii]  # available IR patch
            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
            avl_IR = avl_IR[None, ...]
            avl_VIS = avl_VIS[None, ...]

            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
            train_num += 1

h5f.close()

with h5py.File(os.path.join('data',
                            data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'), "r") as f:
    for key in f.keys():
        print(f[key], key, f[key].name)
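For reference, a minimal sanity check of the generated file, assuming the default data_name, img_size, and stride above and that at least one patch pair was kept:
```
import h5py

with h5py.File("data/MSRS_train_imgsize_128_stride_200.h5", "r") as f:
    print(len(f["ir_patchs"]))        # number of accepted patch pairs
    print(f["ir_patchs"]["0"].shape)  # (1, 128, 128): one [1, win, win] patch
```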
Binary file not shown.
219 train.py Normal file
@@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-

'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc
import kornia


'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'CDDFuse'

# Set the hyper-parameters for training
num_epochs = 120  # total epochs
epoch_gap = 40    # epochs of Phase I
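# The schedule below is two-phase: Phase I (epochs < epoch_gap) trains only the
# encoder/decoder as an autoencoder with reconstruction and decomposition
# losses; Phase II additionally trains the base/detail fusion layers with the
# fusion loss.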

lr = 1e-4
weight_decay = 0
batch_size = 8
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1.  # alpha1
coeff_mse_loss_IF = 1.
coeff_decomp = 2.       # alpha2 and alpha4
coeff_tv = 5.

clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5


# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')


# data loader
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")

'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()

for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()

        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        if epoch < epoch_gap:  # Phase I: reconstruction
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + MSELoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + MSELoss(data_IR, data_IR_hat)

            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))

            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
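            # The decomposition loss drives the detail features of the two
            # modalities apart (the numerator pushes cc_loss_D toward 0) while
            # rewarding correlated base features (a larger cc_loss_B shrinks
            # the term); the 1.01 offset keeps the denominator positive, since
            # cc is clamped to [-1, 1].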

            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss

            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II: fusion
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()

        # Determine approximate time left
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
            % (
                epoch,
                num_epochs,
                i,
                len(loader['train']),
                loss.item(),
                time_left,
            )
        )

    # adjust the learning rate
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        scheduler3.step()
        scheduler4.step()

    # clamp the learning rate from below at 1e-6
    if optimizer1.param_groups[0]['lr'] <= 1e-6:
        optimizer1.param_groups[0]['lr'] = 1e-6
    if optimizer2.param_groups[0]['lr'] <= 1e-6:
        optimizer2.param_groups[0]['lr'] = 1e-6
    if optimizer3.param_groups[0]['lr'] <= 1e-6:
        optimizer3.param_groups[0]['lr'] = 1e-6
    if optimizer4.param_groups[0]['lr'] <= 1e-6:
        optimizer4.param_groups[0]['lr'] = 1e-6

# save the final checkpoint
os.makedirs("models", exist_ok=True)
checkpoint = {
    'DIDF_Encoder': DIDF_Encoder.state_dict(),
    'DIDF_Decoder': DIDF_Decoder.state_dict(),
    'BaseFuseLayer': BaseFuseLayer.state_dict(),
    'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
torch.save(checkpoint, os.path.join("models/CDDFuse_" + timestamp + '.pth'))
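As a usage note, a minimal sketch of restoring this checkpoint later; the timestamped filename is illustrative, and the four modules are assumed to be constructed as in the script above:
```
ckpt = torch.load("models/CDDFuse_01-01-00-00.pth")  # illustrative timestamp
DIDF_Encoder.load_state_dict(ckpt['DIDF_Encoder'])
DIDF_Decoder.load_state_dict(ckpt['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(ckpt['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(ckpt['DetailFuseLayer'])
```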
BIN utils/__pycache__/dataset.cpython-38.pyc Normal file
Binary file not shown.
BIN utils/__pycache__/loss.cpython-38.pyc Normal file
Binary file not shown.
22 utils/dataset.py Normal file
@@ -0,0 +1,22 @@
import torch.utils.data as Data
import h5py
import numpy as np
import torch


class H5Dataset(Data.Dataset):
    def __init__(self, h5file_path):
        self.h5file_path = h5file_path
        h5f = h5py.File(h5file_path, 'r')
        self.keys = list(h5f['ir_patchs'].keys())
        h5f.close()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        # Re-open the file per item so each access gets a fresh handle
        # (safe if the DataLoader is later used with worker processes).
        h5f = h5py.File(self.h5file_path, 'r')
        key = self.keys[index]
        IR = np.array(h5f['ir_patchs'][key])
        VIS = np.array(h5f['vis_patchs'][key])
        h5f.close()
        return torch.Tensor(VIS), torch.Tensor(IR)
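For illustration, a minimal sketch of reading one sample, assuming the H5 file produced by dataprocessing.py:
```
from utils.dataset import H5Dataset

ds = H5Dataset("data/MSRS_train_imgsize_128_stride_200.h5")
vis, ir = ds[0]            # VIS first, IR second, per __getitem__ above
print(len(ds), vis.shape)  # e.g. N torch.Size([1, 128, 128])
```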
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+import numpy as np
 
 
 class Fusionloss(nn.Module):
     def __init__(self):
@@ -38,3 +38,17 @@ class Sobelxy(nn.Module):
         sobelx = F.conv2d(x, self.weightx, padding=1)
         sobely = F.conv2d(x, self.weighty, padding=1)
         return torch.abs(sobelx) + torch.abs(sobely)
+
+
+def cc(img1, img2):
+    """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.]."""
+    eps = torch.finfo(torch.float32).eps
+    N, C, _, _ = img1.shape
+    img1 = img1.reshape(N, C, -1)
+    img2 = img2.reshape(N, C, -1)
+    img1 = img1 - img1.mean(dim=-1, keepdim=True)
+    img2 = img2 - img2.mean(dim=-1, keepdim=True)
+    cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 ** 2, dim=-1))
+                                           * torch.sqrt(torch.sum(img2 ** 2, dim=-1)))
+    cc = torch.clamp(cc, -1., 1.)
+    return cc.mean()
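As a quick check of cc's behavior, a toy example (values are exact up to the eps term):
```
import torch
from utils.loss import cc

x = torch.rand(2, 4, 8, 8)
print(cc(x, x))   # ~1.0: a tensor is perfectly correlated with itself
print(cc(x, -x))  # ~-1.0: and perfectly anti-correlated with its negation
```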