diff --git a/.idea/CDDFuse.iml b/.idea/CDDFuse.iml index 925622b..b6bdc52 100644 (IDE project-settings change; XML hunk omitted)
diff --git a/.idea/deployment.xml b/.idea/deployment.xml index a1c367f..c87715f 100644 (IDE deployment-settings change; XML hunk omitted)
diff --git a/.idea/misc.xml b/.idea/misc.xml index 21ebe7b..91adbbe 100644 (IDE interpreter-settings change; XML hunk omitted)
diff --git a/data/MSRS_train_imgsize_128_stride_200.h5 b/data/MSRS_train_imgsize_128_stride_200.h5 new file mode 100644 index 0000000..e69de29
diff --git a/dataprocessing.py b/dataprocessing.py index 1b986f0..3b2d725 100644 --- a/dataprocessing.py +++ b/dataprocessing.py @@ -12,7 +12,7 @@ def get_img_file(file_name): if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')): imagelist.append(os.path.join(parent, filename)) return imagelist - + def rgb2y(img): y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000 return y @@ -43,12 +43,12 @@ data_name="MSRS_train" img_size=128 #patch size stride=200 #patch stride -IR_files = sorted(get_img_file(r"MSRS_train/ir")) -VIS_files = sorted(get_img_file(r"MSRS_train/vi")) +IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir")) +VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi")) assert len(IR_files) == len(VIS_files) -h5f = h5py.File(os.path.join('.\\data', - data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'), +h5f = h5py.File(os.path.join('./data', + data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'), 'w') h5_ir = h5f.create_group('ir_patchs') h5_vis = h5f.create_group('vis_patchs') @@ -57,11 +57,11 @@ for i in tqdm(range(len(IR_files))): I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32 I_VIS = rgb2y(I_VIS) # [1, H, W] Float32 I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32 - - # crop + + # crop I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride) I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12) - + for ii in range(I_IR_Patch_Group.shape[-1]): bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii]) bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii]) @@ -72,22 +72,22 @@ for i in tqdm(range(len(IR_files))): avl_IR=avl_IR[None,...] avl_VIS=avl_VIS[None,...]
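Note on the dataprocessing.py hunk above: the relative MSRS_train/ir and MSRS_train/vi folders are replaced with absolute paths under /media/star/8TB/whaiDownload, which only resolve on that one machine. A minimal sketch of a more portable alternative; the flag names and defaults below are hypothetical and not part of this patch:

import argparse

# Hypothetical CLI flags so the dataset location is not hard-coded (not in this patch).
parser = argparse.ArgumentParser(description='Build MSRS training patches')
parser.add_argument('--ir_dir', default='MSRS_train/ir', help='folder with infrared images')
parser.add_argument('--vi_dir', default='MSRS_train/vi', help='folder with visible images')
args = parser.parse_args()

# get_img_file is defined earlier in dataprocessing.py.
IR_files = sorted(get_img_file(args.ir_dir))
VIS_files = sorted(get_img_file(args.vi_dir))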
- h5_ir.create_dataset(str(train_num), data=avl_IR, + h5_ir.create_dataset(str(train_num), data=avl_IR, dtype=avl_IR.dtype, shape=avl_IR.shape) - h5_vis.create_dataset(str(train_num), data=avl_VIS, + h5_vis.create_dataset(str(train_num), data=avl_VIS, dtype=avl_VIS.dtype, shape=avl_VIS.shape) - train_num += 1 + train_num += 1 h5f.close() with h5py.File(os.path.join('data', data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f: for key in f.keys(): - print(f[key], key, f[key].name) - - - + print(f[key], key, f[key].name) + + + + - diff --git a/dataprocessing_sar.py b/dataprocessing_sar.py new file mode 100644 index 0000000..3b2d725 --- /dev/null +++ b/dataprocessing_sar.py @@ -0,0 +1,93 @@ +import os +import h5py +import numpy as np +from tqdm import tqdm +from skimage.io import imread + + +def get_img_file(file_name): + imagelist = [] + for parent, dirnames, filenames in os.walk(file_name): + for filename in filenames: + if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')): + imagelist.append(os.path.join(parent, filename)) + return imagelist + +def rgb2y(img): + y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000 + return y + +def Im2Patch(img, win, stride=1): + k = 0 + endc = img.shape[0] + endw = img.shape[1] + endh = img.shape[2] + patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride] + TotalPatNum = patch.shape[1] * patch.shape[2] + Y = np.zeros([endc, win*win,TotalPatNum], np.float32) + for i in range(win): + for j in range(win): + patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride] + Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum) + k = k + 1 + return Y.reshape([endc, win, win, TotalPatNum]) + +def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10, + upper_percentile=90): + """Determine if an image is low contrast.""" + limits = np.percentile(image, [lower_percentile, upper_percentile]) + ratio = (limits[1] - limits[0]) / limits[1] + return ratio < fraction_threshold + +data_name="MSRS_train" +img_size=128 #patch size +stride=200 #patch stride + +IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir")) +VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi")) + +assert len(IR_files) == len(VIS_files) +h5f = h5py.File(os.path.join('./data', + data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'), + 'w') +h5_ir = h5f.create_group('ir_patchs') +h5_vis = h5f.create_group('vis_patchs') +train_num=0 +for i in tqdm(range(len(IR_files))): + I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32 + I_VIS = rgb2y(I_VIS) # [1, H, W] Float32 + I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32 + + # crop + I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride) + I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12) + + for ii in range(I_IR_Patch_Group.shape[-1]): + bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii]) + bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii]) + # Determine if the contrast is low + if not (bad_IR or bad_VIS): + avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR + avl_VIS= I_VIS_Patch_Group[0,:,:,ii] + avl_IR=avl_IR[None,...] + avl_VIS=avl_VIS[None,...] 
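Both dataprocessing.py and the new dataprocessing_sar.py write the surviving 128x128 patch pairs into the ir_patchs and vis_patchs groups, keyed by a running integer. A minimal read-back sketch, assuming the default MSRS settings above (so the output file is data/MSRS_train_imgsize_128_stride_200.h5) and that at least one patch passed the contrast filter:

import h5py
import numpy as np

# Load the first stored IR/visible patch pair back from the generated file.
with h5py.File('data/MSRS_train_imgsize_128_stride_200.h5', 'r') as f:
    ir_patch = np.asarray(f['ir_patchs']['0'])    # (1, 128, 128) float32 in [0, 1]
    vis_patch = np.asarray(f['vis_patchs']['0'])  # (1, 128, 128) float32, Y channel only
print(ir_patch.shape, vis_patch.shape, ir_patch.dtype)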
+ + h5_ir.create_dataset(str(train_num), data=avl_IR, + dtype=avl_IR.dtype, shape=avl_IR.shape) + h5_vis.create_dataset(str(train_num), data=avl_VIS, + dtype=avl_VIS.dtype, shape=avl_VIS.shape) + train_num += 1 + +h5f.close() + +with h5py.File(os.path.join('data', + data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f: + for key in f.keys(): + print(f[key], key, f[key].name) + + + + + + +
diff --git a/net.py b/net.py index 8353a2c..151fb82 100644 --- a/net.py +++ b/net.py @@ -41,8 +41,16 @@ class DropPath(nn.Module): class AttentionBase(nn.Module): + """ + A basic multi-head attention mechanism class. + + Args: + dim (int): Input and output feature dimension. + num_heads (int, optional): Number of attention heads. Defaults to 8. + qkv_bias (bool, optional): Whether to add a bias to the QKV projection layers. Defaults to False. + """ def __init__(self, - dim, + dim, num_heads=8, qkv_bias=False,): super(AttentionBase, self).__init__() @@ -54,6 +62,15 @@ class AttentionBase(nn.Module): self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias) def forward(self, x): + """ + Defines the forward pass of the input x through the multi-head attention mechanism. + + Args: + x (Tensor): Input feature tensor of shape [batch_size, dim, height, width]. + + Returns: + Tensor: Output feature tensor of shape [batch_size, dim, height, width]. + """ # [batch_size, num_patches + 1, total_embed_dim] b, c, h, w = x.shape qkv = self.qkv2(self.qkv1(x)) @@ -78,14 +95,15 @@ class AttentionBase(nn.Module): out = self.proj(out) return out - + + class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ - def __init__(self, - in_features, - hidden_features=None, + def __init__(self, + in_features, + hidden_features=None, ffn_expansion_factor = 2, bias = False): super().__init__() @@ -110,7 +128,7 @@ class BaseFeatureExtraction(nn.Module): def __init__(self, dim, num_heads, - ffn_expansion_factor=1., + ffn_expansion_factor=1., qkv_bias=False,): super(BaseFeatureExtraction, self).__init__() self.norm1 = LayerNorm(dim, 'WithBias') @@ -353,7 +371,7 @@ class Restormer_Encoder(nn.Module): bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) self.baseFeature = BaseFeatureExtraction(dim=dim, num_heads = heads[2]) self.detailFeature = DetailFeatureExtraction() - + def forward(self, inp_img): inp_enc_level1 = self.patch_embed(inp_img) out_enc_level1 = self.encoder_level1(inp_enc_level1) @@ -383,7 +401,7 @@ class Restormer_Decoder(nn.Module): nn.LeakyReLU(), nn.Conv2d(int(dim)//2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),) - self.sigmoid = nn.Sigmoid() + self.sigmoid = nn.Sigmoid() def forward(self, inp_img, base_feature, detail_feature): out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1) out_enc_level0 = self.reduce_channel(out_enc_level0) @@ -393,7 +411,7 @@ class Restormer_Decoder(nn.Module): else: out_enc_level1 = self.output(out_enc_level1) return self.sigmoid(out_enc_level1), out_enc_level0 - + if __name__ == '__main__': height = 128 width = 128
diff --git a/test_IVF.py b/test_IVF.py index f17df6f..9041e85 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -1,3 +1,5 @@ +import cv2 + from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction import os import numpy as np @@ -12,11 +14,11 @@ logging.basicConfig(level=logging.CRITICAL) os.environ["CUDA_VISIBLE_DEVICES"] = "0" ckpt_path=r"models/CDDFuse_IVF.pth" -for dataset_name in ["TNO","RoadScene"]: +for dataset_name in ["TNO"]: print("\n"*2+"="*80) model_name="CDDFuse " print("The test result of "+dataset_name+' :') - test_folder=os.path.join('test_img',dataset_name) + test_folder=os.path.join('test_img',dataset_name)
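For the net.py blocks documented above, a rough CPU-only shape check of the encoder/decoder pair; this is a sketch assuming the default constructor arguments defined in net.py (not shown in this diff), in the spirit of the __main__ block at the end of that file:

import torch
from net import Restormer_Encoder, Restormer_Decoder

# One 1-channel 128x128 input, matching the training patch size.
x = torch.rand(1, 1, 128, 128)
encoder, decoder = Restormer_Encoder(), Restormer_Decoder()
base_feat, detail_feat, _ = encoder(x)         # base / detail / shallow encoder features
fused, _ = decoder(x, base_feat, detail_feat)  # sigmoid output in [0, 1]
print(base_feat.shape, detail_feat.shape, fused.shape)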
test_out_folder=os.path.join('test_result',dataset_name) device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -36,9 +38,12 @@ for dataset_name in ["TNO","RoadScene"]: with torch.no_grad(): for img_name in os.listdir(os.path.join(test_folder,"ir")): + print("Processing: "+img_name) data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0 data_VIS = image_read_cv2(os.path.join(test_folder,"vi",img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0 + data_VIS_BGR = cv2.imread(os.path.join(test_folder,"vi",img_name)) + _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb)) data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS) data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda() @@ -49,11 +54,18 @@ for dataset_name in ["TNO","RoadScene"]: feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D) data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D) data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse)) + # fi = np.squeeze((data_Fuse * 255).cpu().numpy()) + # img_save(fi, img_name.split(sep='.')[0], test_out_folder) + fi = np.squeeze((data_Fuse * 255).cpu().numpy()) - img_save(fi, img_name.split(sep='.')[0], test_out_folder) + fi = fi.astype(np.uint8) + ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb)) + rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB) + img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder) + print("save path : "+os.path.join(test_out_folder,img_name.split(sep='.')[0]+".png")) - eval_folder=test_out_folder + eval_folder=test_out_folder ori_img_folder=test_folder metric_result = np.zeros((8)) @@ -77,4 +89,4 @@ for dataset_name in ["TNO","RoadScene"]: +str(np.round(metric_result[6], 2))+'\t' +str(np.round(metric_result[7], 2)) ) - print("="*80) \ No newline at end of file + print("="*80) diff --git a/train.py b/train.py index 9ca0938..34e91c1 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ Import packages from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction from utils.dataset import H5Dataset import os -os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' import sys import time import datetime @@ -33,8 +33,8 @@ criteria_fusion = Fusionloss() model_str = 'CDDFuse' # . 
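The test_IVF.py change above fuses only the Y (luma) channel and then re-attaches the Cr/Cb chroma of the visible image, so the saved result is a colour image instead of grayscale. A standalone sketch of that colour-restoration step (the file names here are placeholders, not from the patch):

import cv2
import numpy as np

vis_bgr = cv2.imread('vi.png')                                    # visible image, BGR uint8
y, cr, cb = cv2.split(cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2YCrCb))
fused_y = y  # placeholder: in test_IVF.py this is the fused network output scaled to [0, 255]
rgb = cv2.cvtColor(np.dstack((fused_y, cr, cb)), cv2.COLOR_YCrCb2RGB)
cv2.imwrite('fused_rgb.png', cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))  # cv2.imwrite expects BGR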
Set the hyper-parameters for training -num_epochs = 120 # total epoch -epoch_gap = 40 # epoches of Phase I +num_epochs = 10 # total epoch +epoch_gap = 40 # epoches of Phase I lr = 1e-4 weight_decay = 0 @@ -73,7 +73,7 @@ scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, g scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma) scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma) -MSELoss = nn.MSELoss() +MSELoss = nn.MSELoss() L1Loss = nn.L1Loss() Loss_ssim = kornia.losses.SSIM(11, reduction='mean') @@ -130,7 +130,7 @@ for epoch in range(num_epochs): Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS), kornia.filters.SpatialGradient()(data_VIS_hat)) - loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B) + loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B) loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \ mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss @@ -140,24 +140,24 @@ for epoch in range(num_epochs): DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) nn.utils.clip_grad_norm_( DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) - optimizer1.step() + optimizer1.step() optimizer2.step() else: #Phase II feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS) feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR) feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B) feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D) - data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D) + data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D) + - mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse) mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse) cc_loss_B = cc(feature_V_B, feature_I_B) cc_loss_D = cc(feature_V_D, feature_I_D) - loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B) + loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B) fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse) - + loss = fusionloss + coeff_decomp * loss_decomp loss.backward() nn.utils.clip_grad_norm_( @@ -168,7 +168,7 @@ for epoch in range(num_epochs): BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2) nn.utils.clip_grad_norm_( DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2) - optimizer1.step() + optimizer1.step() optimizer2.step() optimizer3.step() optimizer4.step() @@ -192,7 +192,7 @@ for epoch in range(num_epochs): # adjust the learning rate - scheduler1.step() + scheduler1.step() scheduler2.step() if not epoch < epoch_gap: scheduler3.step() @@ -206,7 +206,7 @@ for epoch in range(num_epochs): optimizer3.param_groups[0]['lr'] = 1e-6 if optimizer4.param_groups[0]['lr'] <= 1e-6: optimizer4.param_groups[0]['lr'] = 1e-6 - + if True: checkpoint = { 'DIDF_Encoder': DIDF_Encoder.state_dict(),
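Two notes on the train.py hunks above. First, with num_epochs lowered to 10 while epoch_gap (the number of Phase I epochs) stays at 40, every epoch satisfies epoch < epoch_gap, so only the Phase I branch runs and scheduler3/scheduler4 are never stepped. Second, the checkpoint dict written at the end of training can be reloaded for inference roughly as follows; this is a sketch assuming the key names shown above (e.g. 'DIDF_Encoder') and the models/CDDFuse_IVF.pth path used by test_IVF.py:

import torch
from net import Restormer_Encoder

# Reload the trained encoder from the saved checkpoint dictionary.
ckpt = torch.load('models/CDDFuse_IVF.pth', map_location='cpu')
state = ckpt['DIDF_Encoder']
# Strip the 'module.' prefix that nn.DataParallel adds to keys, if training used it.
state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
encoder = Restormer_Encoder()
encoder.load_state_dict(state)
encoder.eval()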