Adjust the Python environment and code to support new features

- Change the Python version to 3.8.10 and update the related dependencies
- Add a new data processing script, dataprocessing_sar.py
- Expand the network structure description with added comments
- Update the test script to support the new features
- Adjust the loss computation and learning rate policy in the training script
This commit is contained in:
parent 82acfa83dc
commit fe47313277
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/cddfuse/bin/python)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (6)" remoteFilesAllowedToDisappearOnAutoupload="false">
+  <component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (7)" remoteFilesAllowedToDisappearOnAutoupload="false">
   <serverData>
     <paths name="star@192.168.50.108:22 password">
       <serverdata>
@@ -12,7 +12,7 @@
     <paths name="star@192.168.50.108:22 password (2)">
       <serverdata>
         <mappings>
-          <mapping local="$PROJECT_DIR$" web="/" />
+          <mapping deploy="/home/star/whaiDir/CDDFuse" local="$PROJECT_DIR$" web="" />
         </mappings>
       </serverdata>
     </paths>
@@ -44,6 +44,20 @@
         </mappings>
       </serverdata>
     </paths>
+    <paths name="star@192.168.50.108:22 password (7)">
+      <serverdata>
+        <mappings>
+          <mapping deploy="/home/star/whaiDir/CDDFuse" local="$PROJECT_DIR$" />
+        </mappings>
+      </serverdata>
+    </paths>
+    <paths name="star@192.168.50.108:22 password (8)">
+      <serverdata>
+        <mappings>
+          <mapping local="$PROJECT_DIR$" web="/" />
+        </mappings>
+      </serverdata>
+    </paths>
     <paths name="v100">
       <serverdata>
         <mappings>
@@ -1,5 +1,8 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/cddfuse/bin/python)" />
+  </component>
   <component name="MavenImportPreferences">
     <option name="generalSettings">
       <MavenGeneralSettings>
@@ -9,5 +12,5 @@
       </MavenGeneralSettings>
     </option>
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/cddfuse/bin/python)" project-jdk-type="Python SDK" />
 </project>
data/MSRS_train_imgsize_128_stride_200.h5 | 0 (new file)
@@ -12,7 +12,7 @@ def get_img_file(file_name):
         if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
             imagelist.append(os.path.join(parent, filename))
     return imagelist

 def rgb2y(img):
     y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
     return y
@@ -43,12 +43,12 @@ data_name="MSRS_train"
 img_size=128   # patch size
 stride=200     # patch stride

-IR_files = sorted(get_img_file(r"MSRS_train/ir"))
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir"))
-VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi"))

 assert len(IR_files) == len(VIS_files)
-h5f = h5py.File(os.path.join('.\\data',
+h5f = h5py.File(os.path.join('./data',
                data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
                'w')
 h5_ir = h5f.create_group('ir_patchs')
 h5_vis = h5f.create_group('vis_patchs')
@@ -57,11 +57,11 @@ for i in tqdm(range(len(IR_files))):
     I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255.  # [3, H, W] Uint8->float32
     I_VIS = rgb2y(I_VIS)  # [1, H, W] Float32
     I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255.  # [1, H, W] Float32

     # crop
     I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
     I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # (3, 256, 256, 12)

     for ii in range(I_IR_Patch_Group.shape[-1]):
         bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
         bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
@@ -72,22 +72,22 @@ for i in tqdm(range(len(IR_files))):
             avl_IR=avl_IR[None,...]
             avl_VIS=avl_VIS[None,...]

             h5_ir.create_dataset(str(train_num), data=avl_IR,
                                  dtype=avl_IR.dtype, shape=avl_IR.shape)
             h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                   dtype=avl_VIS.dtype, shape=avl_VIS.shape)
             train_num += 1

 h5f.close()

 with h5py.File(os.path.join('data',
               data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
     for key in f.keys():
         print(f[key], key, f[key].name)

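The hunks above retarget the patch extraction at absolute MSRS paths and write every accepted patch pair into ./data/MSRS_train_imgsize_128_stride_200.h5 under the groups ir_patchs and vis_patchs, keyed by a running integer. As a sketch of how a file with this layout can be consumed (the repository's own utils.dataset.H5Dataset, imported in train.py, may differ in details; PatchH5Dataset is a hypothetical name):

```python
import h5py
import torch
from torch.utils.data import Dataset

class PatchH5Dataset(Dataset):
    """Hypothetical reader for the ir_patchs / vis_patchs layout written above."""
    def __init__(self, h5_path):
        self.h5_path = h5_path
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f['ir_patchs'])   # datasets are keyed '0', '1', ...

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.h5_path, 'r') as f:
            ir = torch.from_numpy(f['ir_patchs'][str(idx)][:])    # [1, H, W] float32
            vis = torch.from_numpy(f['vis_patchs'][str(idx)][:])  # [1, H, W] float32
        return ir, vis
```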
dataprocessing_sar.py | 93 additions (new file)
@@ -0,0 +1,93 @@
+import os
+import h5py
+import numpy as np
+from tqdm import tqdm
+from skimage.io import imread
+
+
+def get_img_file(file_name):
+    imagelist = []
+    for parent, dirnames, filenames in os.walk(file_name):
+        for filename in filenames:
+            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
+                imagelist.append(os.path.join(parent, filename))
+    return imagelist
+
+def rgb2y(img):
+    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
+    return y
+
+def Im2Patch(img, win, stride=1):
+    k = 0
+    endc = img.shape[0]
+    endw = img.shape[1]
+    endh = img.shape[2]
+    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
+    TotalPatNum = patch.shape[1] * patch.shape[2]
+    Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
+    for i in range(win):
+        for j in range(win):
+            patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
+            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
+            k = k + 1
+    return Y.reshape([endc, win, win, TotalPatNum])
+
+def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
+                    upper_percentile=90):
+    """Determine if an image is low contrast."""
+    limits = np.percentile(image, [lower_percentile, upper_percentile])
+    ratio = (limits[1] - limits[0]) / limits[1]
+    return ratio < fraction_threshold
+
+data_name="MSRS_train"
+img_size=128   # patch size
+stride=200     # patch stride
+
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi"))
+
+assert len(IR_files) == len(VIS_files)
+h5f = h5py.File(os.path.join('./data',
+                data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+                'w')
+h5_ir = h5f.create_group('ir_patchs')
+h5_vis = h5f.create_group('vis_patchs')
+train_num=0
+for i in tqdm(range(len(IR_files))):
+    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255.  # [3, H, W] Uint8->float32
+    I_VIS = rgb2y(I_VIS)  # [1, H, W] Float32
+    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255.  # [1, H, W] Float32
+
+    # crop
+    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
+    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # (3, 256, 256, 12)
+
+    for ii in range(I_IR_Patch_Group.shape[-1]):
+        bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
+        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
+        # Determine if the contrast is low
+        if not (bad_IR or bad_VIS):
+            avl_IR = I_IR_Patch_Group[0,:,:,ii]   # available IR
+            avl_VIS = I_VIS_Patch_Group[0,:,:,ii]
+            avl_IR = avl_IR[None,...]
+            avl_VIS = avl_VIS[None,...]
+
+            h5_ir.create_dataset(str(train_num), data=avl_IR,
+                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
+            h5_vis.create_dataset(str(train_num), data=avl_VIS,
+                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
+            train_num += 1
+
+h5f.close()
+
+with h5py.File(os.path.join('data',
+               data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
+    for key in f.keys():
+        print(f[key], key, f[key].name)
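Apart from its name, the new script is the same pipeline as the modified script above; nothing in it is SAR-specific yet. The core geometry lives in Im2Patch, which slides a win x win window at the given stride and returns a [C, win, win, N] stack. A standalone sanity check of the window arithmetic, with illustrative image sizes (not taken from MSRS):

```python
import numpy as np

# Standalone re-check of the Im2Patch sliding-window arithmetic above.
# The image size is illustrative; MSRS frames may differ.
img = np.zeros((1, 480, 640), np.float32)   # [C, H, W]
win, stride = 128, 200                      # img_size and stride from the script
n_h = (img.shape[1] - win) // stride + 1    # window positions along H
n_w = (img.shape[2] - win) // stride + 1    # window positions along W
print(n_h * n_w)                            # patches per image: 2 * 3 = 6 here
```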
net.py | 36
@@ -41,8 +41,16 @@ class DropPath(nn.Module):


 class AttentionBase(nn.Module):
+    """
+    A basic multi-head attention module.
+
+    Args:
+        dim (int): input and output feature dimension.
+        num_heads (int, optional): number of attention heads. Default: 8.
+        qkv_bias (bool, optional): whether to add a bias to the QKV projections. Default: False.
+    """
     def __init__(self,
                  dim,
                  num_heads=8,
                  qkv_bias=False,):
         super(AttentionBase, self).__init__()
@@ -54,6 +62,15 @@ class AttentionBase(nn.Module):
         self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)

     def forward(self, x):
+        """
+        Forward pass of the input x through the multi-head attention.
+
+        Args:
+            x (Tensor): input feature tensor of shape [batch_size, dim, height, width].
+
+        Returns:
+            Tensor: output feature tensor of shape [batch_size, dim, height, width].
+        """
         # [batch_size, num_patches + 1, total_embed_dim]
         b, c, h, w = x.shape
         qkv = self.qkv2(self.qkv1(x))
@@ -78,14 +95,15 @@ class AttentionBase(nn.Module):

         out = self.proj(out)
         return out
+

 class Mlp(nn.Module):
     """
     MLP as used in Vision Transformer, MLP-Mixer and related networks
     """
     def __init__(self,
                  in_features,
                  hidden_features=None,
                  ffn_expansion_factor = 2,
                  bias = False):
         super().__init__()
@@ -110,7 +128,7 @@ class BaseFeatureExtraction(nn.Module):
     def __init__(self,
                  dim,
                  num_heads,
                  ffn_expansion_factor=1.,
                  qkv_bias=False,):
         super(BaseFeatureExtraction, self).__init__()
         self.norm1 = LayerNorm(dim, 'WithBias')
@@ -353,7 +371,7 @@ class Restormer_Encoder(nn.Module):
                 bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
         self.baseFeature = BaseFeatureExtraction(dim=dim, num_heads = heads[2])
         self.detailFeature = DetailFeatureExtraction()

     def forward(self, inp_img):
         inp_enc_level1 = self.patch_embed(inp_img)
         out_enc_level1 = self.encoder_level1(inp_enc_level1)
@@ -383,7 +401,7 @@ class Restormer_Decoder(nn.Module):
             nn.LeakyReLU(),
             nn.Conv2d(int(dim)//2, out_channels, kernel_size=3,
                       stride=1, padding=1, bias=bias),)
         self.sigmoid = nn.Sigmoid()
     def forward(self, inp_img, base_feature, detail_feature):
         out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
         out_enc_level0 = self.reduce_channel(out_enc_level0)
@@ -393,7 +411,7 @@ class Restormer_Decoder(nn.Module):
         else:
             out_enc_level1 = self.output(out_enc_level1)
         return self.sigmoid(out_enc_level1), out_enc_level0

 if __name__ == '__main__':
     height = 128
     width = 128
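The net.py diff is documentation-only: it adds class and forward docstrings for AttentionBase (rendered in English above; the commit wrote them in Chinese) without touching behaviour. For orientation, the following is a self-contained sketch of the Restormer-style multi-head channel attention that the qkv1/qkv2/proj layers and the docstring shapes suggest. It is an illustration consistent with the diff, not the repository's exact implementation; dim must be divisible by num_heads:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class ChannelAttentionSketch(nn.Module):
    """Illustrative Restormer-style attention over channels, not net.py itself."""
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        # Learned per-head temperature instead of the usual 1/sqrt(d) scaling.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv1 = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
        self.qkv2 = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1,
                              groups=dim * 3, bias=qkv_bias)   # depthwise mixing
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.qkv2(self.qkv1(x)).chunk(3, dim=1)
        # Fold spatial positions into the token axis; attend channel-to-channel.
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature   # [b, head, c', c']
        out = attn.softmax(dim=-1) @ v                        # [b, head, c', h*w]
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', h=h, w=w)
        return self.proj(out)   # same shape as the input, as the docstring states

# Shape check: [1, 64, 128, 128] in -> [1, 64, 128, 128] out
y = ChannelAttentionSketch(dim=64)(torch.randn(1, 64, 128, 128))
```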
test_IVF.py | 22
@@ -1,3 +1,5 @@
+import cv2
+
 from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
 import os
 import numpy as np
@@ -12,11 +14,11 @@ logging.basicConfig(level=logging.CRITICAL)

 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 ckpt_path=r"models/CDDFuse_IVF.pth"
-for dataset_name in ["TNO","RoadScene"]:
+for dataset_name in ["TNO"]:
     print("\n"*2+"="*80)
     model_name="CDDFuse "
     print("The test result of "+dataset_name+' :')
     test_folder=os.path.join('test_img',dataset_name)
     test_out_folder=os.path.join('test_result',dataset_name)

     device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -36,9 +38,12 @@ for dataset_name in ["TNO","RoadScene"]:

     with torch.no_grad():
         for img_name in os.listdir(os.path.join(test_folder,"ir")):
+            print("Processing: "+img_name)
+
             data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
             data_VIS = image_read_cv2(os.path.join(test_folder,"vi",img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
+            data_VIS_BGR = cv2.imread(os.path.join(test_folder,"vi",img_name))
+            _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))

             data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
             data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
@@ -49,11 +54,18 @@ for dataset_name in ["TNO","RoadScene"]:
             feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
             data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
             data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
+            # fi = np.squeeze((data_Fuse * 255).cpu().numpy())
+            # img_save(fi, img_name.split(sep='.')[0], test_out_folder)
+
             fi = np.squeeze((data_Fuse * 255).cpu().numpy())
-            img_save(fi, img_name.split(sep='.')[0], test_out_folder)
+            fi = fi.astype(np.uint8)
+            ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
+            rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
+            img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
+            print("save path : "+os.path.join(test_out_folder,img_name.split(sep='.')[0]+".png"))

     eval_folder=test_out_folder
     ori_img_folder=test_folder

     metric_result = np.zeros((8))
@@ -77,4 +89,4 @@ for dataset_name in ["TNO","RoadScene"]:
                 +str(np.round(metric_result[6], 2))+'\t'
                 +str(np.round(metric_result[7], 2))
                 )
     print("="*80)
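The test-script changes swap the grayscale output for colour: inference still runs on single-channel luminance, but the Cr and Cb planes of the visible image are split off before the forward pass and stitched back onto the fused Y channel before saving. Condensed into one helper (restore_color is a hypothetical name, not a function in the repository):

```python
import cv2
import numpy as np

def restore_color(fused_y, vis_bgr):
    """Recombine a fused luminance plane with the visible image's chroma.

    fused_y: uint8 array [H, W] (network output scaled to 0..255);
    vis_bgr: uint8 array [H, W, 3] as returned by cv2.imread.
    """
    _, cr, cb = cv2.split(cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2YCrCb))
    ycrcb = np.dstack((fused_y, cr, cb))          # matches np.dstack in the diff
    return cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2RGB)
```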
train.py | 26
@@ -9,7 +9,7 @@ Import packages
 from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
 from utils.dataset import H5Dataset
 import os
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
 import sys
 import time
 import datetime
@@ -33,8 +33,8 @@ criteria_fusion = Fusionloss()
 model_str = 'CDDFuse'

 # . Set the hyper-parameters for training
-num_epochs = 120  # total epoch
+num_epochs = 10  # total epoch
 epoch_gap = 40  # epoches of Phase I

 lr = 1e-4
 weight_decay = 0
@@ -73,7 +73,7 @@ scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
 scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
 scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

 MSELoss = nn.MSELoss()
 L1Loss = nn.L1Loss()
 Loss_ssim = kornia.losses.SSIM(11, reduction='mean')

@@ -130,7 +130,7 @@ for epoch in range(num_epochs):
             Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                    kornia.filters.SpatialGradient()(data_VIS_hat))

             loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

             loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                 mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss
@@ -140,24 +140,24 @@ for epoch in range(num_epochs):
                 DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
             nn.utils.clip_grad_norm_(
                 DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
             optimizer1.step()
             optimizer2.step()
         else:  # Phase II
             feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
             feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
             feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
             feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
             data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)


             mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
             mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)

             cc_loss_B = cc(feature_V_B, feature_I_B)
             cc_loss_D = cc(feature_V_D, feature_I_D)
             loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
             fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

             loss = fusionloss + coeff_decomp * loss_decomp
             loss.backward()
             nn.utils.clip_grad_norm_(
@@ -168,7 +168,7 @@ for epoch in range(num_epochs):
                 BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
             nn.utils.clip_grad_norm_(
                 DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
             optimizer1.step()
             optimizer2.step()
             optimizer3.step()
             optimizer4.step()
@@ -192,7 +192,7 @@ for epoch in range(num_epochs):

     # adjust the learning rate

     scheduler1.step()
     scheduler2.step()
     if not epoch < epoch_gap:
         scheduler3.step()
@@ -206,7 +206,7 @@ for epoch in range(num_epochs):
         optimizer3.param_groups[0]['lr'] = 1e-6
     if optimizer4.param_groups[0]['lr'] <= 1e-6:
         optimizer4.param_groups[0]['lr'] = 1e-6

 if True:
     checkpoint = {
         'DIDF_Encoder': DIDF_Encoder.state_dict(),
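Two things stand out in the train.py diff: num_epochs drops from 120 to 10 while epoch_gap stays at 40, so Phase II is never reached in a 10-epoch run, and the StepLR schedulers are clamped to a floor of 1e-6 per optimizer. The learning-rate policy in isolation (optim_step and optim_gamma are not visible in the diff, so the values below are placeholders):

```python
import torch

# StepLR decay with a hard floor of 1e-6, mirroring the structure above.
# step_size=20 and gamma=0.5 stand in for optim_step / optim_gamma.
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)

for epoch in range(10):           # num_epochs after this commit
    # ... forward pass, loss.backward(), opt.step() ...
    sched.step()                  # decay the learning rate
    if opt.param_groups[0]['lr'] <= 1e-6:
        opt.param_groups[0]['lr'] = 1e-6   # clamp, as train.py does per optimizer
```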