Adjust the Python environment and code to support new features

- Change the Python version to 3.8.10 and update the related dependencies
- Add a new data processing script, dataprocessing_sar.py
- Revise the network-structure docstrings and add comments
- Update the test script to support the new features
- Adjust the loss computation and learning-rate strategy in the training script
whaifree 2024-10-06 14:42:13 +08:00
parent 82acfa83dc
commit fe47313277
9 changed files with 187 additions and 47 deletions

View File

@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/cddfuse/bin/python)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>

View File

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (6)" remoteFilesAllowedToDisappearOnAutoupload="false">
+  <component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (7)" remoteFilesAllowedToDisappearOnAutoupload="false">
     <serverData>
       <paths name="star@192.168.50.108:22 password">
         <serverdata>
@@ -12,7 +12,7 @@
       <paths name="star@192.168.50.108:22 password (2)">
         <serverdata>
           <mappings>
-            <mapping local="$PROJECT_DIR$" web="/" />
+            <mapping deploy="/home/star/whaiDir/CDDFuse" local="$PROJECT_DIR$" web="" />
          </mappings>
        </serverdata>
      </paths>
@@ -44,6 +44,20 @@
       </mappings>
     </serverdata>
   </paths>
+      <paths name="star@192.168.50.108:22 password (7)">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/home/star/whaiDir/CDDFuse" local="$PROJECT_DIR$" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="star@192.168.50.108:22 password (8)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
       <paths name="v100">
         <serverdata>
           <mappings>

View File

@@ -1,5 +1,8 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/cddfuse/bin/python)" />
+  </component>
   <component name="MavenImportPreferences">
     <option name="generalSettings">
       <MavenGeneralSettings>
@@ -9,5 +12,5 @@
       </MavenGeneralSettings>
     </option>
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/cddfuse/bin/python)" project-jdk-type="Python SDK" />
 </project>

View File

@@ -43,11 +43,11 @@ data_name="MSRS_train"
 img_size=128  # patch size
 stride=200    # patch stride
-IR_files = sorted(get_img_file(r"MSRS_train/ir"))
-VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi"))
 assert len(IR_files) == len(VIS_files)
-h5f = h5py.File(os.path.join('.\\data',
+h5f = h5py.File(os.path.join('./data',
                data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
                'w')
 h5_ir = h5f.create_group('ir_patchs')
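Note on the path change above: os.path.join builds separators for the current platform, so replacing the Windows-style '.\\data' prefix with './data' keeps the script working on the Linux training host. A quick illustration:

    import os

    # On Linux this prints ./data/MSRS_train_imgsize_128_stride_200.h5;
    # on Windows the join would use a backslash instead.
    print(os.path.join('./data', 'MSRS_train_imgsize_128_stride_200.h5'))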

dataprocessing_sar.py (new file, +93 lines)
View File

@@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread

def get_img_file(file_name):
    imagelist = []
    for parent, dirnames, filenames in os.walk(file_name):
        for filename in filenames:
            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
                imagelist.append(os.path.join(parent, filename))
    return imagelist

def rgb2y(img):
    # ITU-R BT.601 luma weights
    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
    return y

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
                    upper_percentile=90):
    """Determine if an image is low contrast."""
    limits = np.percentile(image, [lower_percentile, upper_percentile])
    ratio = (limits[1] - limits[0]) / limits[1]
    return ratio < fraction_threshold

data_name = "MSRS_train"
img_size = 128  # patch size
stride = 200    # patch stride

IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/ir"))
VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/MSRS-main/train/vi"))
assert len(IR_files) == len(VIS_files)

h5f = h5py.File(os.path.join('./data',
               data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
               'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num = 0

for i in tqdm(range(len(IR_files))):
    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1) / 255.  # [3, H, W] uint8 -> float32
    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :] / 255.  # [1, H, W] float32
    # crop
    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # (3, 256, 256, 12)
    for ii in range(I_IR_Patch_Group.shape[-1]):
        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
        # Keep only patch pairs where neither modality is low contrast
        if not (bad_IR or bad_VIS):
            avl_IR = I_IR_Patch_Group[0, :, :, ii]  # available IR
            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
            avl_IR = avl_IR[None, ...]
            avl_VIS = avl_VIS[None, ...]
            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
            train_num += 1
h5f.close()

with h5py.File(os.path.join('data',
               data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'), "r") as f:
    for key in f.keys():
        print(f[key], key, f[key].name)
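The closing block only prints the stored keys back. For reference, here is a minimal sketch of how a training pipeline might consume the resulting file's 'ir_patchs'/'vis_patchs' groups; the H5Dataset class below is an illustration, not part of this commit:

    import h5py
    import torch
    from torch.utils.data import Dataset

    class H5Dataset(Dataset):
        """Hypothetical reader for the patch groups written above."""
        def __init__(self, h5_path):
            self.h5_path = h5_path
            with h5py.File(h5_path, 'r') as f:
                self.keys = list(f['ir_patchs'].keys())

        def __len__(self):
            return len(self.keys)

        def __getitem__(self, idx):
            # Open per item so the dataset also works with DataLoader workers.
            with h5py.File(self.h5_path, 'r') as f:
                ir = torch.from_numpy(f['ir_patchs'][self.keys[idx]][:])    # [1, 128, 128]
                vis = torch.from_numpy(f['vis_patchs'][self.keys[idx]][:])  # [1, 128, 128]
            return ir, vis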

net.py (+18 lines)
View File

@@ -41,6 +41,14 @@ class DropPath(nn.Module):
 class AttentionBase(nn.Module):
+    """
+    A basic multi-head attention module.
+
+    Args:
+        dim (int): feature dimension of the input and output
+        num_heads (int, optional): number of attention heads; defaults to 8
+        qkv_bias (bool, optional): whether the QKV projection layers use a bias; defaults to False
+    """
     def __init__(self,
                  dim,
                  num_heads=8,
@@ -54,6 +62,15 @@ class AttentionBase(nn.Module):
         self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)
 
     def forward(self, x):
+        """
+        Forward pass of the input x through the multi-head attention.
+
+        Args:
+            x (Tensor): input feature tensor of shape [batch_size, dim, height, width]
+
+        Returns:
+            Tensor: output feature tensor of shape [batch_size, dim, height, width]
+        """
         # [batch_size, num_patches + 1, total_embed_dim]
         b, c, h, w = x.shape
         qkv = self.qkv2(self.qkv1(x))
@@ -79,6 +96,7 @@ class AttentionBase(nn.Module):
         out = self.proj(out)
         return out
 
 class Mlp(nn.Module):
     """
     MLP as used in Vision Transformer, MLP-Mixer and related networks
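The new docstrings describe attention computed directly on [batch_size, dim, height, width] feature maps, with qkv1/qkv2 producing queries, keys and values and proj mapping back to dim channels. As context for reviewers, below is a minimal sketch of that Restormer-style channel-attention pattern; the layer shapes, einops usage, and the per-head temperature are assumptions based on common implementations, not the exact contents of net.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from einops import rearrange

    class AttentionSketch(nn.Module):
        def __init__(self, dim, num_heads=8, qkv_bias=False):
            super().__init__()
            self.num_heads = num_heads
            # learnable per-head temperature, common in Restormer-style attention
            self.scale = nn.Parameter(torch.ones(num_heads, 1, 1))
            self.qkv1 = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
            self.qkv2 = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1,
                                  groups=dim * 3, bias=qkv_bias)
            self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)

        def forward(self, x):
            b, c, h, w = x.shape
            q, k, v = self.qkv2(self.qkv1(x)).chunk(3, dim=1)
            # flatten spatial dims; attention runs across channels per head
            q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            out = rearrange(attn @ v, 'b head c (h w) -> b (head c) h w',
                            head=self.num_heads, h=h, w=w)
            return self.proj(out)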

View File

@@ -1,3 +1,5 @@
+import cv2
+
 from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
 import os
 import numpy as np
@@ -12,7 +14,7 @@ logging.basicConfig(level=logging.CRITICAL)
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 ckpt_path=r"models/CDDFuse_IVF.pth"
-for dataset_name in ["TNO","RoadScene"]:
+for dataset_name in ["TNO"]:
     print("\n"*2+"="*80)
     model_name="CDDFuse "
     print("The test result of "+dataset_name+' :')
@@ -36,9 +38,12 @@ for dataset_name in ["TNO","RoadScene"]:
     with torch.no_grad():
         for img_name in os.listdir(os.path.join(test_folder,"ir")):
             print("Processing: "+img_name)
             data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
             data_VIS = image_read_cv2(os.path.join(test_folder,"vi",img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
+            data_VIS_BGR = cv2.imread(os.path.join(test_folder,"vi",img_name))
+            _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
             data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
             data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
@@ -49,8 +54,15 @@ for dataset_name in ["TNO","RoadScene"]:
             feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
             data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
             data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
+            # fi = np.squeeze((data_Fuse * 255).cpu().numpy())
+            # img_save(fi, img_name.split(sep='.')[0], test_out_folder)
             fi = np.squeeze((data_Fuse * 255).cpu().numpy())
             img_save(fi, img_name.split(sep='.')[0], test_out_folder)
+            fi = fi.astype(np.uint8)
+            ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
+            rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
+            img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
+            print("save path : "+os.path.join(test_out_folder,img_name.split(sep='.')[0]+".png"))
     eval_folder=test_out_folder
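For context on the additions above: the network fuses only the luminance channel, and the Cr/Cb planes taken from the visible image are stacked back onto the fused Y before conversion to RGB. A tiny standalone sketch of that recomposition (the input path is hypothetical):

    import cv2
    import numpy as np

    bgr = cv2.imread('vi/example.png')  # hypothetical visible image
    y, cr, cb = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2YCrCb))
    fused_y = y  # stand-in for the network's fused luminance output (uint8)
    rgb = cv2.cvtColor(np.dstack((fused_y, cr, cb)), cv2.COLOR_YCrCb2RGB)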

View File

@@ -33,7 +33,7 @@ criteria_fusion = Fusionloss()
 model_str = 'CDDFuse'
 
 # . Set the hyper-parameters for training
-num_epochs = 120  # total epochs
+num_epochs = 10  # total epochs
 epoch_gap = 40  # epochs of Phase I
 
 lr = 1e-4
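Only the epoch count is visible in this hunk; the loss and learning-rate changes mentioned in the commit message sit elsewhere in the file. Note that with num_epochs = 10 and epoch_gap = 40, training now ends before Phase II would ever start. As a sketch of how such a two-phase split is typically gated (the loss gating and step decay below are assumptions, not this repository's exact code):

    num_epochs = 10  # total epochs, as in the diff
    epoch_gap = 40   # length of Phase I, as in the diff
    lr = 1e-4

    for epoch in range(num_epochs):
        phase_one = epoch < epoch_gap  # always True here: num_epochs < epoch_gap
        current_lr = lr * (0.5 ** (epoch // 20))  # hypothetical step decay
        # Hypothetical gating: Phase I trains reconstruction only;
        # Phase II would add the fusion objective on top:
        # loss = loss_recon if phase_one else loss_recon + loss_fusion
        print(epoch, phase_one, current_lr)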