diff --git a/README.md b/README.md
index cd84ada..8dc2e74 100644
--- a/README.md
+++ b/README.md
@@ -85,11 +85,11 @@ If you want to infer with our CDDFuse and obtain the fusion results in our paper
 ```
 python test_IVF.py
 ```
-for IVF and
+for Infrared-Visible Fusion and
 ```
 python test_MIF.py
 ```
-for MIF.
+for Medical Image Fusion.
 
 The testing results will be printed in the terminal.
 
diff --git a/dataprocessing.py b/dataprocessing.py
new file mode 100644
index 0000000..1b986f0
--- /dev/null
+++ b/dataprocessing.py
@@ -0,0 +1,93 @@
+import os
+import h5py
+import numpy as np
+from tqdm import tqdm
+from skimage.io import imread
+
+
+def get_img_file(file_name):
+    imagelist = []
+    for parent, dirnames, filenames in os.walk(file_name):
+        for filename in filenames:
+            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
+                imagelist.append(os.path.join(parent, filename))
+    return imagelist
+
+def rgb2y(img):
+    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
+    return y
+
+def Im2Patch(img, win, stride=1):
+    k = 0
+    endc = img.shape[0]
+    endw = img.shape[1]
+    endh = img.shape[2]
+    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
+    TotalPatNum = patch.shape[1] * patch.shape[2]
+    Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
+    for i in range(win):
+        for j in range(win):
+            patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
+            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
+            k = k + 1
+    return Y.reshape([endc, win, win, TotalPatNum])
+
+def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
+                    upper_percentile=90):
+    """Determine if an image is low contrast."""
+    limits = np.percentile(image, [lower_percentile, upper_percentile])
+    ratio = (limits[1] - limits[0]) / limits[1]
+    return ratio < fraction_threshold
+
+data_name = "MSRS_train"
+img_size = 128  # patch size
+stride = 200    # patch stride
+
+IR_files = sorted(get_img_file(r"MSRS_train/ir"))
+VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
+
+assert len(IR_files) == len(VIS_files)
+h5f = h5py.File(os.path.join('data',
+                             data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+               'w')
+h5_ir = h5f.create_group('ir_patchs')
+h5_vis = h5f.create_group('vis_patchs')
+train_num = 0
+for i in tqdm(range(len(IR_files))):
+    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1)/255.  # [3, H, W] uint8 -> float32
+    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
+    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255.  # [1, H, W] float32
+
+    # crop into patches
+    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
+    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # [1, img_size, img_size, N]
+
+    for ii in range(I_IR_Patch_Group.shape[-1]):
+        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
+        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
+        # keep only patch pairs where neither modality is low contrast
+        if not (bad_IR or bad_VIS):
+            avl_IR = I_IR_Patch_Group[0, :, :, ii]  # available IR
+            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
+            avl_IR = avl_IR[None, ...]
+            avl_VIS = avl_VIS[None, ...]
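+            # store each accepted [1, img_size, img_size] patch pair under its own index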
+
+            h5_ir.create_dataset(str(train_num), data=avl_IR,
+                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
+            h5_vis.create_dataset(str(train_num), data=avl_VIS,
+                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
+            train_num += 1
+
+h5f.close()
+
+with h5py.File(os.path.join('data',
+                            data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'), "r") as f:
+    for key in f.keys():
+        print(f[key], key, f[key].name)
+
+
+
+
+
+
+
diff --git a/models/CDDFuse_IVF.pth b/models/CDDFuse_IVF.pth
index 12f497c..d279422 100644
Binary files a/models/CDDFuse_IVF.pth and b/models/CDDFuse_IVF.pth differ
diff --git a/models/CDDFuse_MIF.pth b/models/CDDFuse_MIF.pth
index a5fa411..072e39a 100644
Binary files a/models/CDDFuse_MIF.pth and b/models/CDDFuse_MIF.pth differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..9ca0938
--- /dev/null
+++ b/train.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+
+'''
+------------------------------------------------------------------------------
+Import packages
+------------------------------------------------------------------------------
+'''
+
+from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
+from utils.dataset import H5Dataset
+import os
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+import sys
+import time
+import datetime
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from utils.loss import Fusionloss, cc
+import kornia
+
+
+
+'''
+------------------------------------------------------------------------------
+Configure our network
+------------------------------------------------------------------------------
+'''
+
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+criteria_fusion = Fusionloss()
+model_str = 'CDDFuse'
+
+# Set the hyper-parameters for training
+num_epochs = 120  # total epochs
+epoch_gap = 40    # epochs of Phase I
+
+lr = 1e-4
+weight_decay = 0
+batch_size = 8
+GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
+# Coefficients of the loss function
+coeff_mse_loss_VF = 1.  # alpha1
+coeff_mse_loss_IF = 1.
+coeff_decomp = 2.       # alpha2 and alpha4
+coeff_tv = 5.
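+# coeff_tv weights the Phase-I gradient (texture) consistency term; coeff_decomp
+# scales the correlation-driven decomposition loss used in both training phases.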
+ +clip_grad_norm_value = 0.01 +optim_step = 20 +optim_gamma = 0.5 + + +# Model +device = 'cuda' if torch.cuda.is_available() else 'cpu' +DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) +DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) +BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) +DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) + +# optimizer, scheduler and loss function +optimizer1 = torch.optim.Adam( + DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay) +optimizer2 = torch.optim.Adam( + DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay) +optimizer3 = torch.optim.Adam( + BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay) +optimizer4 = torch.optim.Adam( + DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay) + +scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma) +scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma) +scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma) +scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma) + +MSELoss = nn.MSELoss() +L1Loss = nn.L1Loss() +Loss_ssim = kornia.losses.SSIM(11, reduction='mean') + + +# data loader +trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"), + batch_size=batch_size, + shuffle=True, + num_workers=0) + +loader = {'train': trainloader, } +timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M") + +''' +------------------------------------------------------------------------------ +Train +------------------------------------------------------------------------------ +''' + +step = 0 +torch.backends.cudnn.benchmark = True +prev_time = time.time() + +for epoch in range(num_epochs): + ''' train ''' + for i, (data_VIS, data_IR) in enumerate(loader['train']): + data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda() + DIDF_Encoder.train() + DIDF_Decoder.train() + BaseFuseLayer.train() + DetailFuseLayer.train() + + DIDF_Encoder.zero_grad() + DIDF_Decoder.zero_grad() + BaseFuseLayer.zero_grad() + DetailFuseLayer.zero_grad() + + optimizer1.zero_grad() + optimizer2.zero_grad() + optimizer3.zero_grad() + optimizer4.zero_grad() + + if epoch < epoch_gap: #Phase I + feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS) + feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR) + data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D) + data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D) + + cc_loss_B = cc(feature_V_B, feature_I_B) + cc_loss_D = cc(feature_V_D, feature_I_D) + mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + MSELoss(data_VIS, data_VIS_hat) + mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + MSELoss(data_IR, data_IR_hat) + + Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS), + kornia.filters.SpatialGradient()(data_VIS_hat)) + + loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B) + + loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \ + mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + + loss.backward() + nn.utils.clip_grad_norm_( + DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + optimizer1.step() + optimizer2.step() + else: #Phase II + feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS) + 
feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR) + feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B) + feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D) + data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D) + + + mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse) + mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse) + + cc_loss_B = cc(feature_V_B, feature_I_B) + cc_loss_D = cc(feature_V_D, feature_I_D) + loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B) + fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse) + + loss = fusionloss + coeff_decomp * loss_decomp + loss.backward() + nn.utils.clip_grad_norm_( + DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + optimizer1.step() + optimizer2.step() + optimizer3.step() + optimizer4.step() + + # Determine approximate time left + batches_done = epoch * len(loader['train']) + i + batches_left = num_epochs * len(loader['train']) - batches_done + time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) + prev_time = time.time() + sys.stdout.write( + "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s" + % ( + epoch, + num_epochs, + i, + len(loader['train']), + loss.item(), + time_left, + ) + ) + + # adjust the learning rate + + scheduler1.step() + scheduler2.step() + if not epoch < epoch_gap: + scheduler3.step() + scheduler4.step() + + if optimizer1.param_groups[0]['lr'] <= 1e-6: + optimizer1.param_groups[0]['lr'] = 1e-6 + if optimizer2.param_groups[0]['lr'] <= 1e-6: + optimizer2.param_groups[0]['lr'] = 1e-6 + if optimizer3.param_groups[0]['lr'] <= 1e-6: + optimizer3.param_groups[0]['lr'] = 1e-6 + if optimizer4.param_groups[0]['lr'] <= 1e-6: + optimizer4.param_groups[0]['lr'] = 1e-6 + +if True: + checkpoint = { + 'DIDF_Encoder': DIDF_Encoder.state_dict(), + 'DIDF_Decoder': DIDF_Decoder.state_dict(), + 'BaseFuseLayer': BaseFuseLayer.state_dict(), + 'DetailFuseLayer': DetailFuseLayer.state_dict(), + } + torch.save(checkpoint, os.path.join("models/CDDFuse_"+timestamp+'.pth')) + + diff --git a/utils/__pycache__/Evaluator.cpython-38.pyc b/utils/__pycache__/Evaluator.cpython-38.pyc index 9933cef..f9a696e 100644 Binary files a/utils/__pycache__/Evaluator.cpython-38.pyc and b/utils/__pycache__/Evaluator.cpython-38.pyc differ diff --git a/utils/__pycache__/dataset.cpython-38.pyc b/utils/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000..4c9eab3 Binary files /dev/null and b/utils/__pycache__/dataset.cpython-38.pyc differ diff --git a/utils/__pycache__/img_read_save.cpython-38.pyc b/utils/__pycache__/img_read_save.cpython-38.pyc index d3e4b27..241519a 100644 Binary files a/utils/__pycache__/img_read_save.cpython-38.pyc and b/utils/__pycache__/img_read_save.cpython-38.pyc differ diff --git a/utils/__pycache__/loss.cpython-38.pyc b/utils/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000..aa0f00e Binary files /dev/null and b/utils/__pycache__/loss.cpython-38.pyc differ diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..666175f --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,22 @@ +import torch.utils.data as Data +import h5py +import numpy as np 
+import torch
+
+class H5Dataset(Data.Dataset):
+    def __init__(self, h5file_path):
+        self.h5file_path = h5file_path
+        h5f = h5py.File(h5file_path, 'r')
+        self.keys = list(h5f['ir_patchs'].keys())
+        h5f.close()
+
+    def __len__(self):
+        return len(self.keys)
+
+    def __getitem__(self, index):
+        h5f = h5py.File(self.h5file_path, 'r')
+        key = self.keys[index]
+        IR = np.array(h5f['ir_patchs'][key])
+        VIS = np.array(h5f['vis_patchs'][key])
+        h5f.close()
+        return torch.Tensor(VIS), torch.Tensor(IR)
\ No newline at end of file
diff --git a/utils/loss.py b/utils/loss.py
index bba08db..acb17b8 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
+
 class Fusionloss(nn.Module):
     def __init__(self):
@@ -38,3 +38,17 @@ class Sobelxy(nn.Module):
         sobelx=F.conv2d(x, self.weightx, padding=1)
         sobely=F.conv2d(x, self.weighty, padding=1)
         return torch.abs(sobelx)+torch.abs(sobely)
+
+
+def cc(img1, img2):
+    """Correlation coefficient for (N, C, H, W) float32 tensors; clamped to [-1., 1.]."""
+    eps = torch.finfo(torch.float32).eps
+    N, C, _, _ = img1.shape
+    img1 = img1.reshape(N, C, -1)
+    img2 = img2.reshape(N, C, -1)
+    img1 = img1 - img1.mean(dim=-1, keepdim=True)
+    img2 = img2 - img2.mean(dim=-1, keepdim=True)
+    cc = torch.sum(img1 * img2, dim=-1) / (
+        eps + torch.sqrt(torch.sum(img1 ** 2, dim=-1)) * torch.sqrt(torch.sum(img2 ** 2, dim=-1)))
+    cc = torch.clamp(cc, -1., 1.)
+    return cc.mean()
\ No newline at end of file
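
For reference, a minimal usage sketch of the new `cc` helper (not part of the patch itself), assuming `utils/loss.py` above is importable from the repository root; it mirrors how `train.py` turns the base/detail feature correlations into the decomposition loss `cc_D ** 2 / (1.01 + cc_B)`:
```
# Minimal sketch: correlation coefficient on (N, C, H, W) feature maps.
import torch
from utils.loss import cc

feat_base_V = torch.randn(4, 64, 32, 32)                          # base features, visible branch
feat_base_I = feat_base_V + 0.1 * torch.randn_like(feat_base_V)   # strongly correlated base features
feat_detail_V = torch.randn(4, 64, 32, 32)                        # detail features, visible branch
feat_detail_I = torch.randn(4, 64, 32, 32)                        # largely uncorrelated detail features

cc_B = cc(feat_base_V, feat_base_I)      # close to 1 for correlated inputs
cc_D = cc(feat_detail_V, feat_detail_I)  # close to 0 for uncorrelated inputs
loss_decomp = cc_D ** 2 / (1.01 + cc_B)  # same form as loss_decomp in train.py
print(cc_B.item(), cc_D.item(), loss_decomp.item())
```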