pfcfuse/test_sar.py
whaifree afd55abe9e 模型结构
DetailFeatureExtraction增加了一个增强残差
BaseFeatureExtraction增加了
x = self.WTConv2d(x)
2024-10-07 15:24:33 +08:00

202 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import sys
import uuid
import cv2
from PIL import Image
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save, image_read_cv2
import warnings
import logging
# Keep CLI output clean: hide library warnings and any logging below CRITICAL.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
# Directory of the launched script, with a trailing separator.  abspath() is
# needed because dirname("test_sar.py") is "" when the script is started from
# its own folder; os.sep (instead of a hard-coded "\\") keeps this portable.
path = os.path.dirname(os.path.abspath(sys.argv[0])) + os.sep
# Restrict the process to the first GPU (only effective if CUDA has not been
# initialised yet).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# CDDFuse checkpoint, expected under <script dir>/models/.
ckpt_path = os.path.join(path, "models", "CDDFuse_IVF.pth")
print(torch.cuda.is_available())
def _load_fusion_models(device):
    """Build the four CDDFuse sub-networks on ``device`` and load their weights.

    The checkpoint at the module-level ``ckpt_path`` is read once and mapped
    onto ``device`` so that weights saved from a GPU also load on CPU-only
    machines.

    Returns:
        Tuple ``(encoder, decoder, base_fuse_layer, detail_fuse_layer)``,
        each wrapped in ``nn.DataParallel`` and switched to eval mode.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    models = (
        nn.DataParallel(Restormer_Encoder()).to(device),
        nn.DataParallel(Restormer_Decoder()).to(device),
        nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device),
        nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device),
    )
    keys = ('DIDF_Encoder', 'DIDF_Decoder', 'BaseFuseLayer', 'DetailFuseLayer')
    for model, key in zip(models, keys):
        model.load_state_dict(checkpoint[key])
        model.eval()
    return models


def _fuse_pair(sar_path, vi_path, device, models):
    """Fuse one SAR / visible image pair and return an RGB ``uint8`` array.

    The SAR image takes the place of the infrared branch of CDDFuse.  The
    network fuses the grey-scale luminance only; the visible image's Cr/Cb
    chroma is re-attached afterwards.  Callers are expected to run this under
    ``torch.no_grad()``.
    """
    encoder, decoder, base_fuse_layer, detail_fuse_layer = models
    # Grey-scale inputs scaled to [0, 1] with two leading singleton axes added
    # (presumably batch and channel — confirm against Restormer_Encoder).
    data_sar = image_read_cv2(sar_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
    data_vis = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
    # Chroma of the visible image (YCrCb split), reused for the colour output.
    _, vis_cr, vis_cb = cv2.split(cv2.cvtColor(cv2.imread(vi_path), cv2.COLOR_BGR2YCrCb))
    # Move to the selected device rather than hard-coding .cuda(), so the
    # script also runs on CPU-only machines.
    data_sar = torch.FloatTensor(data_sar).to(device)
    data_vis = torch.FloatTensor(data_vis).to(device)

    feature_v_b, feature_v_d, _ = encoder(data_vis)
    feature_s_b, feature_s_d, _ = encoder(data_sar)
    feature_f_b = base_fuse_layer(feature_v_b + feature_s_b)
    feature_f_d = detail_fuse_layer(feature_v_d + feature_s_d)
    data_fuse, _ = decoder(data_vis, feature_f_b, feature_f_d)

    # Min-max normalise to [0, 1]; the clamp avoids 0/0 (NaN) when the
    # decoder output is constant.
    lo, hi = torch.min(data_fuse), torch.max(data_fuse)
    data_fuse = (data_fuse - lo) / torch.clamp(hi - lo, min=1e-8)
    fused_y = np.squeeze((data_fuse * 255).cpu().numpy()).astype(np.uint8)
    # Recombine fused luminance with the visible chroma and convert to RGB.
    return cv2.cvtColor(np.dstack((fused_y, vis_cr, vis_cb)), cv2.COLOR_YCrCb2RGB)


def main(opt):
    """Fuse a SAR image with a visible-light image and save the colour result.

    Both inputs are first turned into downscaled JPEG copies by ``transfer``
    (quality 100, scale 0.15).  The fused image is saved into
    ``opt.outputPath`` under the name ``fusionSAR_<converted SAR file stem>``.

    Args:
        opt: namespace with ``sarPath``, ``viPath`` and ``outputPath``
            attributes, as produced by ``parse_opt``.
    """
    # Example:
    #   --sarPath D:\...\test_img\Test\ir\ir_2.png
    #   --viPath  D:\...\test_img\Test\vi\ir_2.png
    #   --outputPath D:\...\test_img\Test\
    sar_path = transfer(opt.sarPath, 100, 0.15)
    vi_path = transfer(opt.viPath, 100, 0.15)
    output_path = opt.outputPath
    print("\n" * 2 + "=" * 80)
    print("The sar_path of " + sar_path + ' :')
    print("The vi_path of " + vi_path + ' :')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    models = _load_fusion_models(device)
    with torch.no_grad():
        rgb_fi = _fuse_pair(sar_path, vi_path, device, models)

    # File name without its extension, taken from the converted SAR image.
    file_name = os.path.splitext(os.path.basename(sar_path))[0]
    img_save(rgb_fi, "fusionSAR_" + file_name, output_path)
    # os.path.join adds the separator that plain concatenation dropped when
    # output_path has no trailing slash.
    # NOTE(review): the ".jpg" extension assumes img_save writes JPEG — confirm
    # against utils.img_read_save.
    print("输出文件路径:" + os.path.join(output_path, "fusionSAR_" + file_name + ".jpg"))
    # TODO: a metric evaluation (EN, SD, SF, MI, SCD, VIF, Qabf, SSIM via
    # utils.Evaluator) was previously sketched here as commented-out code;
    # reinstate it as a separate function if needed.
def transfer(input_path, quality=20, resize_factor=0.1):
    """Write a downscaled JPEG copy of an image next to the original.

    Args:
        input_path: path of the source image (any format Pillow can open,
            e.g. TIFF).
        quality: JPEG quality passed to Pillow's ``save``.
        resize_factor: scale applied to both width and height.  Each side is
            truncated to an int and clamped to at least 1 pixel.

    Returns:
        Path of the written JPEG: ``<dir>/<stem>.jpg``.  If that path is the
        input file itself (a ``.jpg`` input), ``<dir>/<stem>_resized.jpg`` is
        used instead, so the original is never overwritten.
    """
    folder = os.path.dirname(input_path)
    stem = os.path.splitext(os.path.basename(input_path))[0]
    output_path = os.path.join(folder, stem + '.jpg')
    # Without this check a .jpg input would be replaced by its own shrunken
    # copy, and each re-run would shrink it again.
    if os.path.exists(output_path) and os.path.samefile(input_path, output_path):
        output_path = os.path.join(folder, stem + '_resized.jpg')

    # The context manager closes the source file handle once pixels are read.
    with Image.open(input_path) as img:
        new_size = (max(1, int(img.width * resize_factor)),
                    max(1, int(img.height * resize_factor)))
        # convert('RGB') drops any alpha channel, which JPEG cannot store.
        rgb_img = img.resize(new_size).convert('RGB')
    rgb_img.save(output_path, 'JPEG', quality=quality)
    print(f'{output_path} 转换完成')
    return output_path
def parse_opt(argv=None):
    """Parse the command-line options of the SAR / visible fusion script.

    Args:
        argv: list of argument strings to parse; ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``sarPath``, ``viPath`` and ``outputPath``.
        argparse exits with status 2 if any of them is missing.
    """
    parser = argparse.ArgumentParser(
        description='python.exe --sarPath "sar绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
    # Every option is required, so argparse would never fall back to a
    # ``default``; no defaults are declared.
    parser.add_argument('--sarPath', type=str, required=True,
                        help="SAR图像的绝对路径")
    parser.add_argument('--viPath', type=str, required=True,
                        help="可见光图像的绝对路径")
    parser.add_argument('--outputPath', type=str, required=True,
                        help="输出路径!")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: report CUDA availability, then read the CLI
    # options and run the fusion pipeline.
    print(torch.cuda.is_available())
    opt = parse_opt()
    main(opt)
def add_prefix_to_files(directory_path, prefix):
    """Rename every regular file in ``directory_path`` to ``<prefix>_<name>``.

    Sub-directories are skipped and the walk is not recursive.  All target
    names are checked before anything is renamed.  If any target already
    exists, ``FileExistsError`` is raised and the directory is left unchanged.
    This matters because ``os.rename`` silently overwrites on POSIX.

    Args:
        directory_path: directory whose files are renamed.
        prefix: text prepended to each file name, joined with ``_``.

    Raises:
        FileExistsError: if one or more target names are already taken.
    """
    renames = []
    for old_filename in sorted(os.listdir(directory_path)):
        old_path = os.path.join(directory_path, old_filename)
        if not os.path.isfile(old_path):
            continue
        new_filename = f"{prefix}_{old_filename}"
        new_path = os.path.join(directory_path, new_filename)
        renames.append((old_filename, new_filename, old_path, new_path))

    clashes = [new_filename for _, new_filename, _, new_path in renames
               if os.path.exists(new_path)]
    if clashes:
        raise FileExistsError(
            f"target names already exist in {directory_path}: {', '.join(clashes)}")

    for old_filename, new_filename, old_path, new_path in renames:
        os.rename(old_path, new_path)
        print(f'{old_filename} 重命名为 {new_filename}')
# 替换为实际的目录路径和前缀
# directory_path = '/path/to/your/directory'
# new_prefix = 'new'
#
# # 执行批量重命名
# add_prefix_to_files(directory_path, new_prefix)