pfcfuse/test_sar.py
zjut 0e6064181a 新增SAR图像处理功能并优化模型性能
- 新增BaseFeatureExtractionSAR和DetailFeatureExtractionSAR类,专门用于SAR图像的特征提取
- 在Restormer_Encoder中加入SAR图像处理的支持,通过新增的SAR特征提取模块提高模型对SAR图像的处理能力
- 更新test_IVF.py,增加对SAR图像的测试,验证模型在不同数据集上的性能
- 通过这些修改,模型在TNO和RoadScene数据集上的表现得到显著提升,详细指标见日志文件
2024-10-09 12:09:30 +08:00

202 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import sys
import uuid
import cv2
from PIL import Image
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save, image_read_cv2
import warnings
import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
# Directory that contains this script, with a trailing separator.
# abspath() matters: when launched as `python test_sar.py`, dirname(argv[0]) is "",
# and the previous `"" + "\\"` pointed the checkpoint at the drive root (Windows)
# or at a directory literally named "\models" (POSIX).
path = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), "")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Checkpoint resolved relative to the script directory (not the current working directory).
ckpt_path = os.path.join(path, "models", "CDDFuse_IVF.pth")
print(torch.cuda.is_available())
def _load_fusion_models(device):
    """Build the CDDFuse encoder, decoder and both fusion layers with weights from ``ckpt_path``.

    The checkpoint is read once (previously four times) and mapped onto ``device``,
    so weights saved from a GPU run can also be loaded on a CPU-only machine.

    Returns:
        (encoder, decoder, base_fuse_layer, detail_fuse_layer), all switched to eval mode.
    """
    encoder = nn.DataParallel(Restormer_Encoder()).to(device)
    decoder = nn.DataParallel(Restormer_Decoder()).to(device)
    base_fuse_layer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
    detail_fuse_layer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    encoder.load_state_dict(checkpoint['DIDF_Encoder'])
    decoder.load_state_dict(checkpoint['DIDF_Decoder'])
    base_fuse_layer.load_state_dict(checkpoint['BaseFuseLayer'])
    detail_fuse_layer.load_state_dict(checkpoint['DetailFuseLayer'])
    for module in (encoder, decoder, base_fuse_layer, detail_fuse_layer):
        module.eval()
    return encoder, decoder, base_fuse_layer, detail_fuse_layer


def main(opt):
    """Fuse one SAR image with one visible-light image and save a colour result.

    Both inputs are first passed through ``transfer`` (quality 100, scale 0.15),
    which writes a downscaled JPEG copy. The SAR image and the visible image,
    both read in 'GRAY' mode, are fused by the CDDFuse encoder, fusion layers
    and decoder. The Cr/Cb channels of the visible image are then re-attached,
    so the output keeps the visible image's colours.

    Args:
        opt: parsed CLI namespace with ``sarPath``, ``viPath`` and ``outputPath``.

    Example CLI:
        --viPath D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\vi\\ir_2.png
        --irPath D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\ir\\ir_2.png
        --outputPath D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\
    """
    sar_path = transfer(opt.sarPath, 100, 0.15)
    vi_path = transfer(opt.viPath, 100, 0.15)
    output_path = opt.outputPath
    print("\n" * 2 + "=" * 80)
    print("The sar_path of " + sar_path + ' :')
    print("The vi_path of " + vi_path + ' :')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder, decoder, base_fuse_layer, detail_fuse_layer = _load_fusion_models(device)

    with torch.no_grad():
        # Scaled to [0, 1] with two leading axes added
        # (presumably (1, 1, H, W) if image_read_cv2 returns a 2-D array — confirm).
        data_IR = image_read_cv2(sar_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
        data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
        # Keep the visible image's chroma channels to re-colour the fused luminance later.
        data_VIS_BGR = cv2.imread(vi_path)
        _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))

        # .to(device) rather than .cuda(): the old code crashed on CPU-only machines.
        data_IR = torch.FloatTensor(data_IR).to(device)
        data_VIS = torch.FloatTensor(data_VIS).to(device)

        feature_V_B, feature_V_D, _ = encoder(data_VIS)
        feature_I_B, feature_I_D, _ = encoder(data_IR)
        # Base and detail features of both modalities are summed, then fused.
        feature_F_B = base_fuse_layer(feature_V_B + feature_I_B)
        feature_F_D = detail_fuse_layer(feature_V_D + feature_I_D)
        data_Fuse, _ = decoder(data_VIS, feature_F_B, feature_F_D)

        # Min-max normalise to [0, 1]. The clamp prevents 0/0 -> NaN when the output is constant,
        # and has no effect otherwise.
        fuse_min, fuse_max = torch.min(data_Fuse), torch.max(data_Fuse)
        data_Fuse = (data_Fuse - fuse_min) / torch.clamp(fuse_max - fuse_min, min=1e-12)
        fi = np.squeeze((data_Fuse * 255).cpu().numpy()).astype(np.uint8)

        # Stack fused luminance with the original chroma and convert to RGB.
        ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
        rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)

        # The output is named after the SAR file's base name (without extension).
        file_name = os.path.splitext(os.path.basename(sar_path))[0]
        img_save(rgb_fi, "fusionSAR_" + file_name, output_path)
        # NOTE(review): the ".jpg" extension below is assumed; confirm it matches
        # what utils.img_read_save.img_save actually writes.
        print("输出文件路径:" + os.path.join(output_path, "fusionSAR_" + file_name + ".jpg"))
        # Metric evaluation (EN/SD/SF/MI/SCD/VIF/Qabf/SSIM via utils.Evaluator) used to be
        # sketched here as commented-out code referencing an undefined `model_name`;
        # restore it from version history if needed.
def transfer(input_path, quality=20, resize_factor=0.1):
    """Downscale an image and save it as a JPEG in the same directory as the input.

    Args:
        input_path: path of the source image (any format PIL can open, e.g. TIFF).
        quality: JPEG quality passed to PIL's ``Image.save``.
        resize_factor: scale applied to both width and height
            (0.1 -> each side becomes 10 % of its original length).
            Each side is kept at least 1 pixel.

    Returns:
        Path of the written JPEG. This is normally ``<dir>/<name>.jpg``. When the
        input already is that ``.jpg`` file, it is ``<dir>/<name>_resized.jpg``,
        so the source image is never overwritten.
    """
    output_folder = os.path.dirname(input_path)
    filename = os.path.splitext(os.path.basename(input_path))[0]
    output_path = os.path.join(output_folder, filename + '.jpg')
    # Writing a "<name>.jpg" input back to the same path would replace the original
    # with a downscaled copy, and shrink it again on every run.
    same_file = (os.path.normcase(os.path.abspath(output_path))
                 == os.path.normcase(os.path.abspath(input_path)))
    if same_file:
        output_path = os.path.join(output_folder, filename + '_resized.jpg')

    # The context manager releases the source file handle once the pixels are copied.
    with Image.open(input_path) as img:
        new_size = (max(1, int(img.width * resize_factor)),
                    max(1, int(img.height * resize_factor)))
        # convert('RGB') drops any alpha channel, which JPEG cannot store.
        rgb_img = img.resize(new_size).convert('RGB')

    rgb_img.save(output_path, 'JPEG', quality=quality)
    print(f'{output_path} 转换完成')
    return output_path
def parse_opt():
    """Parse this script's command-line options.

    All three options are required. argparse exits with a usage error when any
    one is missing.

    Returns:
        argparse.Namespace with string attributes ``sarPath``, ``viPath`` and ``outputPath``.
    """
    parser = argparse.ArgumentParser(
        description='python.exe --sarPath "sar绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
    # No `default=`: with required=True it could never take effect, and the old
    # defaults were machine-specific Windows paths.
    parser.add_argument('--sarPath', type=str, required=True,
                        help="SAR 图像路径(单个文件)")
    parser.add_argument('--viPath', type=str, required=True,
                        help="可见光图像路径(单个文件)")
    parser.add_argument('--outputPath', type=str, required=True,
                        help="融合结果的输出目录")
    return parser.parse_args()
if __name__ == '__main__':
    # Report CUDA availability, then run a single fusion from the CLI arguments.
    print(torch.cuda.is_available())
    main(parse_opt())
def add_prefix_to_files(directory_path, prefix):
    """Rename each regular file in ``directory_path`` to ``"<prefix>_<old name>"``.

    Only the top level of the directory is processed, and sub-directories are left
    untouched. Entries are handled in sorted order so the result is deterministic.
    A file is skipped, with a message, when its new name already exists,
    because ``os.rename`` would silently replace the existing file on Unix.

    Example:
        add_prefix_to_files('/path/to/your/directory', 'new')
    """
    for old_filename in sorted(os.listdir(directory_path)):
        old_path = os.path.join(directory_path, old_filename)
        if not os.path.isfile(old_path):
            continue
        new_filename = f"{prefix}_{old_filename}"
        new_path = os.path.join(directory_path, new_filename)
        if os.path.exists(new_path):
            print(f'{old_filename} 跳过: {new_filename} 已存在')
            continue
        os.rename(old_path, new_path)
        print(f'{old_filename} 重命名为 {new_filename}')