diff --git a/.idea/CDDFuse.iml b/.idea/CDDFuse.iml
index 925622b..b6bdc52 100644
--- a/.idea/CDDFuse.iml
+++ b/.idea/CDDFuse.iml
@@ -2,7 +2,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
index a1c367f..c87715f 100644
--- a/.idea/deployment.xml
+++ b/.idea/deployment.xml
@@ -1,6 +1,6 @@
-
+
@@ -12,7 +12,7 @@
-
+
@@ -44,6 +44,20 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 21ebe7b..91adbbe 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,5 +1,8 @@
+
+
+
-
+
\ No newline at end of file
diff --git a/data/MSRS_train_imgsize_128_stride_200.h5 b/data/MSRS_train_imgsize_128_stride_200.h5
new file mode 100644
index 0000000..e69de29
diff --git a/dataprocessing.py b/dataprocessing.py
index 1b986f0..3b2d725 100644
--- a/dataprocessing.py
+++ b/dataprocessing.py
@@ -12,7 +12,7 @@ def get_img_file(file_name):
if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
imagelist.append(os.path.join(parent, filename))
return imagelist
-
+
def rgb2y(img):
y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
return y
@@ -43,12 +43,12 @@ data_name="MSRS_train"
img_size=128 #patch size
stride=200 #patch stride
-IR_files = sorted(get_img_file(r"MSRS_train/ir"))
-VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi"))
assert len(IR_files) == len(VIS_files)
-h5f = h5py.File(os.path.join('.\\data',
- data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+h5f = h5py.File(os.path.join('./data',
+ data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
@@ -57,11 +57,11 @@ for i in tqdm(range(len(IR_files))):
I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
-
- # crop
+
+ # crop
I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12)
-
+
for ii in range(I_IR_Patch_Group.shape[-1]):
bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
@@ -72,22 +72,22 @@ for i in tqdm(range(len(IR_files))):
avl_IR=avl_IR[None,...]
avl_VIS=avl_VIS[None,...]
- h5_ir.create_dataset(str(train_num), data=avl_IR,
+ h5_ir.create_dataset(str(train_num), data=avl_IR,
dtype=avl_IR.dtype, shape=avl_IR.shape)
- h5_vis.create_dataset(str(train_num), data=avl_VIS,
+ h5_vis.create_dataset(str(train_num), data=avl_VIS,
dtype=avl_VIS.dtype, shape=avl_VIS.shape)
- train_num += 1
+ train_num += 1
h5f.close()
with h5py.File(os.path.join('data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
for key in f.keys():
- print(f[key], key, f[key].name)
-
-
-
+ print(f[key], key, f[key].name)
+
+
+
+
-
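
A minimal sketch of reading the patch pairs back from the file generated above, assuming the group names 'ir_patchs' / 'vis_patchs' and the sequential dataset names written by the script:

import h5py
import numpy as np

h5_path = './data/MSRS_train_imgsize_128_stride_200.h5'    # path produced by dataprocessing.py
with h5py.File(h5_path, 'r') as f:
    keys = sorted(f['ir_patchs'].keys(), key=int)           # datasets are named "0", "1", ...
    ir_patch = np.asarray(f['ir_patchs'][keys[0]])          # [1, 128, 128] float32 in [0, 1]
    vis_patch = np.asarray(f['vis_patchs'][keys[0]])        # [1, 128, 128] float32 (Y channel only)
    print(len(keys), ir_patch.shape, vis_patch.shape)
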
diff --git a/dataprocessing_sar.py b/dataprocessing_sar.py
new file mode 100644
index 0000000..3b2d725
--- /dev/null
+++ b/dataprocessing_sar.py
@@ -0,0 +1,93 @@
+import os
+import h5py
+import numpy as np
+from tqdm import tqdm
+from skimage.io import imread
+
+
+def get_img_file(file_name):
+ imagelist = []
+ for parent, dirnames, filenames in os.walk(file_name):
+ for filename in filenames:
+ if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
+ imagelist.append(os.path.join(parent, filename))
+ return imagelist
+
+def rgb2y(img):
+ y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
+ return y
+
+def Im2Patch(img, win, stride=1):
+ k = 0
+ endc = img.shape[0]
+ endw = img.shape[1]
+ endh = img.shape[2]
+ patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
+ TotalPatNum = patch.shape[1] * patch.shape[2]
+ Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
+ for i in range(win):
+ for j in range(win):
+ patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
+ Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
+ k = k + 1
+ return Y.reshape([endc, win, win, TotalPatNum])
+
+def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
+ upper_percentile=90):
+ """Determine if an image is low contrast."""
+ limits = np.percentile(image, [lower_percentile, upper_percentile])
+ ratio = (limits[1] - limits[0]) / limits[1]
+ return ratio < fraction_threshold
+
+data_name="MSRS_train"
+img_size=128 #patch size
+stride=200 #patch stride
+
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi"))
+
+assert len(IR_files) == len(VIS_files)
+h5f = h5py.File(os.path.join('./data',
+ data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+ 'w')
+h5_ir = h5f.create_group('ir_patchs')
+h5_vis = h5f.create_group('vis_patchs')
+train_num=0
+for i in tqdm(range(len(IR_files))):
+ I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
+ I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
+ I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
+
+ # crop
+ I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
+ I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12)
+
+ for ii in range(I_IR_Patch_Group.shape[-1]):
+ bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
+ bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
+ # Determine if the contrast is low
+ if not (bad_IR or bad_VIS):
+ avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR
+ avl_VIS= I_VIS_Patch_Group[0,:,:,ii]
+ avl_IR=avl_IR[None,...]
+ avl_VIS=avl_VIS[None,...]
+
+ h5_ir.create_dataset(str(train_num), data=avl_IR,
+ dtype=avl_IR.dtype, shape=avl_IR.shape)
+ h5_vis.create_dataset(str(train_num), data=avl_VIS,
+ dtype=avl_VIS.dtype, shape=avl_VIS.shape)
+ train_num += 1
+
+h5f.close()
+
+with h5py.File(os.path.join('data',
+ data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
+ for key in f.keys():
+ print(f[key], key, f[key].name)
+
+
+
+
+
+
+
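
Im2Patch slides a win x win window over the [C, H, W] input with the given stride; a quick arithmetic sketch of how many patches each frame yields (the 480x640 frame size is an assumption for illustration, not taken from the script):

# Mirrors the slicing in Im2Patch: window positions along one axis are 0, stride, 2*stride, ...
# up to the last position <= length - win.
def n_patches(length, win, stride):
    return (length - win) // stride + 1

win, stride = 128, 200                    # values used in this script
H, W = 480, 640                           # assumed frame size, illustration only
print(n_patches(H, win, stride) * n_patches(W, win, stride))   # 2 * 3 = 6 patches per frame
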
diff --git a/net.py b/net.py
index 8353a2c..151fb82 100644
--- a/net.py
+++ b/net.py
@@ -41,8 +41,16 @@ class DropPath(nn.Module):
class AttentionBase(nn.Module):
+ """
+ 一个基础的多头注意力机制类。
+
+ 参数:
+ dim (int): 输入和输出的特征维度。
+ num_heads (int, 可选): 注意力头的数量,默认为8。
+ qkv_bias (bool, 可选): 是否为QKV投影层添加偏差,默认为False。
+ """
def __init__(self,
- dim,
+ dim,
num_heads=8,
qkv_bias=False,):
super(AttentionBase, self).__init__()
@@ -54,6 +62,15 @@ class AttentionBase(nn.Module):
self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)
def forward(self, x):
+ """
+ 定义了输入数据x通过多头注意力机制的前向传播过程。
+
+ 参数:
+ x (Tensor): 输入的特征张量,形状为[batch_size, dim, height, width]。
+
+ 返回:
+ Tensor: 输出的特征张量,形状为[batch_size, dim, height, width]。
+ """
# [batch_size, num_patches + 1, total_embed_dim]
b, c, h, w = x.shape
qkv = self.qkv2(self.qkv1(x))
@@ -78,14 +95,15 @@ class AttentionBase(nn.Module):
out = self.proj(out)
return out
-
+
+
class Mlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
- def __init__(self,
- in_features,
- hidden_features=None,
+ def __init__(self,
+ in_features,
+ hidden_features=None,
ffn_expansion_factor = 2,
bias = False):
super().__init__()
@@ -110,7 +128,7 @@ class BaseFeatureExtraction(nn.Module):
def __init__(self,
dim,
num_heads,
- ffn_expansion_factor=1.,
+ ffn_expansion_factor=1.,
qkv_bias=False,):
super(BaseFeatureExtraction, self).__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
@@ -353,7 +371,7 @@ class Restormer_Encoder(nn.Module):
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim, num_heads = heads[2])
self.detailFeature = DetailFeatureExtraction()
-
+
def forward(self, inp_img):
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
@@ -383,7 +401,7 @@ class Restormer_Decoder(nn.Module):
nn.LeakyReLU(),
nn.Conv2d(int(dim)//2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias),)
- self.sigmoid = nn.Sigmoid()
+ self.sigmoid = nn.Sigmoid()
def forward(self, inp_img, base_feature, detail_feature):
out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
out_enc_level0 = self.reduce_channel(out_enc_level0)
@@ -393,7 +411,7 @@ class Restormer_Decoder(nn.Module):
else:
out_enc_level1 = self.output(out_enc_level1)
return self.sigmoid(out_enc_level1), out_enc_level0
-
+
if __name__ == '__main__':
height = 128
width = 128
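
A shape check matching the docstring added to AttentionBase above (input and output both [batch_size, dim, height, width]); this assumes net.py imports cleanly on its own with torch and einops installed, and that dim is divisible by num_heads:

import torch
from net import AttentionBase

attn = AttentionBase(dim=64, num_heads=8, qkv_bias=False)
x = torch.randn(1, 64, 128, 128)          # [batch_size, dim, height, width]
with torch.no_grad():
    y = attn(x)
print(x.shape, y.shape)                   # attention is shape-preserving: both [1, 64, 128, 128]
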
diff --git a/test_IVF.py b/test_IVF.py
index f17df6f..9041e85 100644
--- a/test_IVF.py
+++ b/test_IVF.py
@@ -1,3 +1,5 @@
+import cv2
+
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
@@ -12,11 +14,11 @@ logging.basicConfig(level=logging.CRITICAL)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path=r"models/CDDFuse_IVF.pth"
-for dataset_name in ["TNO","RoadScene"]:
+for dataset_name in ["TNO"]:
print("\n"*2+"="*80)
model_name="CDDFuse "
print("The test result of "+dataset_name+' :')
- test_folder=os.path.join('test_img',dataset_name)
+ test_folder=os.path.join('test_img',dataset_name)
test_out_folder=os.path.join('test_result',dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -36,9 +38,12 @@ for dataset_name in ["TNO","RoadScene"]:
with torch.no_grad():
for img_name in os.listdir(os.path.join(test_folder,"ir")):
+ print("Processing: "+img_name)
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
data_VIS = image_read_cv2(os.path.join(test_folder,"vi",img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
+ data_VIS_BGR = cv2.imread(os.path.join(test_folder,"vi",img_name))
+ _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
@@ -49,11 +54,18 @@ for dataset_name in ["TNO","RoadScene"]:
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
+ # fi = np.squeeze((data_Fuse * 255).cpu().numpy())
+ # img_save(fi, img_name.split(sep='.')[0], test_out_folder)
+
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
- img_save(fi, img_name.split(sep='.')[0], test_out_folder)
+ fi = fi.astype(np.uint8)
+ ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
+ rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
+ img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
+ print("save path : "+os.path.join(test_out_folder,img_name.split(sep='.')[0]+".png"))
- eval_folder=test_out_folder
+ eval_folder=test_out_folder
ori_img_folder=test_folder
metric_result = np.zeros((8))
@@ -77,4 +89,4 @@ for dataset_name in ["TNO","RoadScene"]:
+str(np.round(metric_result[6], 2))+'\t'
+str(np.round(metric_result[7], 2))
)
- print("="*80)
\ No newline at end of file
+ print("="*80)
diff --git a/train.py b/train.py
index 9ca0938..34e91c1 100644
--- a/train.py
+++ b/train.py
@@ -9,7 +9,7 @@ Import packages
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
@@ -33,8 +33,8 @@ criteria_fusion = Fusionloss()
model_str = 'CDDFuse'
# . Set the hyper-parameters for training
-num_epochs = 120 # total epoch
-epoch_gap = 40 # epoches of Phase I
+num_epochs = 10 # total epoch
+epoch_gap = 40 # epoches of Phase I
lr = 1e-4
weight_decay = 0
@@ -73,7 +73,7 @@ scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, g
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)
-MSELoss = nn.MSELoss()
+MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
@@ -130,7 +130,7 @@ for epoch in range(num_epochs):
Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
kornia.filters.SpatialGradient()(data_VIS_hat))
- loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
+ loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss
@@ -140,24 +140,24 @@ for epoch in range(num_epochs):
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
- optimizer1.step()
+ optimizer1.step()
optimizer2.step()
else: #Phase II
feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
- data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
+ data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
+
-
mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)
cc_loss_B = cc(feature_V_B, feature_I_B)
cc_loss_D = cc(feature_V_D, feature_I_D)
- loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
+ loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse)
-
+
loss = fusionloss + coeff_decomp * loss_decomp
loss.backward()
nn.utils.clip_grad_norm_(
@@ -168,7 +168,7 @@ for epoch in range(num_epochs):
BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
- optimizer1.step()
+ optimizer1.step()
optimizer2.step()
optimizer3.step()
optimizer4.step()
@@ -192,7 +192,7 @@ for epoch in range(num_epochs):
# adjust the learning rate
- scheduler1.step()
+ scheduler1.step()
scheduler2.step()
if not epoch < epoch_gap:
scheduler3.step()
@@ -206,7 +206,7 @@ for epoch in range(num_epochs):
optimizer3.param_groups[0]['lr'] = 1e-6
if optimizer4.param_groups[0]['lr'] <= 1e-6:
optimizer4.param_groups[0]['lr'] = 1e-6
-
+
if True:
checkpoint = {
'DIDF_Encoder': DIDF_Encoder.state_dict(),
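
The decomposition loss kept unchanged in both training phases is cc(detail_V, detail_I)**2 / (1.01 + cc(base_V, base_I)); a minimal sketch of that term with a plain correlation-coefficient helper (the actual cc comes from the repo's loss utilities, so this helper is only an assumption for illustration):

import torch

def cc(a, b, eps=1e-8):
    # per-sample correlation coefficient between two feature maps, averaged over the batch
    a = a.flatten(1) - a.flatten(1).mean(dim=1, keepdim=True)
    b = b.flatten(1) - b.flatten(1).mean(dim=1, keepdim=True)
    return ((a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1) + eps)).mean()

feature_V_B, feature_I_B = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)
feature_V_D, feature_I_D = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)

cc_loss_B = cc(feature_V_B, feature_I_B)
cc_loss_D = cc(feature_V_D, feature_I_D)
loss_decomp = cc_loss_D ** 2 / (1.01 + cc_loss_B)   # same expression as in train.py
print(loss_decomp.item())
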