d
This commit is contained in:
parent
936e5d4ef1
commit
d32424f54a
132
test.py
Normal file
132
test.py
Normal file
@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from matplotlib import image as mpimg, pyplot as plt
|
||||
|
||||
from net import Sar_Restormer_Encoder,Vi_Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||
import os
|
||||
import numpy as np
|
||||
from utils.Evaluator import Evaluator
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils.img_read_save import img_save, image_read_cv2
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logging.basicConfig(level=logging.CRITICAL)
|
||||
|
||||
path = os.path.dirname(sys.argv[0]) + "\\"
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
# ckpt_path=r"models/CDDFuse_IVF.pth"
|
||||
ckpt_path = r"" + path + "models/CDDFuse_04-10-11-56.pth"
|
||||
|
||||
print(torch.cuda.is_available())
|
||||
|
||||
|
||||
def main(opt):
|
||||
# --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\
|
||||
|
||||
ir_path = opt.irPath
|
||||
vi_path = opt.viPath
|
||||
output_path = opt.outputPath
|
||||
|
||||
print("\n" * 2 + "=" * 80)
|
||||
|
||||
model_name = "CDDFuse "
|
||||
print("The ir_path of " + ir_path + ' :')
|
||||
print("The vi_path of " + vi_path + ' :')
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
SAR_Encoder = nn.DataParallel(Sar_Restormer_Encoder()).to(device)
|
||||
VI_Encoder = nn.DataParallel(Vi_Restormer_Encoder()).to(device)
|
||||
|
||||
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||
|
||||
SAR_Encoder.load_state_dict(torch.load(ckpt_path)['SAR_DIDF_Encoder'])
|
||||
VI_Encoder.load_state_dict(torch.load(ckpt_path)['VI_DIDF_Encoder'])
|
||||
|
||||
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
|
||||
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
|
||||
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
|
||||
SAR_Encoder.eval()
|
||||
VI_Encoder.eval()
|
||||
|
||||
Decoder.eval()
|
||||
BaseFuseLayer.eval()
|
||||
DetailFuseLayer.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
data_IR = image_read_cv2(ir_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
|
||||
data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
|
||||
|
||||
data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
|
||||
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||
|
||||
feature_V_B, feature_V_D, feature_V = VI_Encoder(data_VIS)
|
||||
feature_I_B, feature_I_D, feature_I = SAR_Encoder(data_IR)
|
||||
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
|
||||
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
|
||||
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
|
||||
data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
|
||||
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
|
||||
|
||||
# 获取文件名(包含后缀)
|
||||
file_name_with_extension = os.path.basename(ir_path)
|
||||
# 分离文件名和文件后缀
|
||||
file_name, file_extension = os.path.splitext(file_name_with_extension)
|
||||
|
||||
img_save(fi, "fusion_" + file_name, output_path)
|
||||
print("输出文件路径:" + output_path + "fusion_" + file_name + ".png")
|
||||
|
||||
metric_result = np.zeros((8))
|
||||
irImagePath = ir_path
|
||||
ir = image_read_cv2(irImagePath, 'GRAY')
|
||||
viImagePath = vi_path
|
||||
vi = image_read_cv2(viImagePath, 'GRAY')
|
||||
|
||||
fusionImagePath = os.path.join(output_path, "fusion_{}.png".format(file_name))
|
||||
fi = image_read_cv2(fusionImagePath, 'GRAY')
|
||||
# 统计
|
||||
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
|
||||
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
|
||||
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
|
||||
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
|
||||
|
||||
metric_result /= len(os.listdir(output_path))
|
||||
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
|
||||
print("对比结果:" + model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t'
|
||||
+ str(np.round(metric_result[1], 2)) + '\t'
|
||||
+ str(np.round(metric_result[2], 2)) + '\t'
|
||||
+ str(np.round(metric_result[3], 2)) + '\t'
|
||||
+ str(np.round(metric_result[4], 2)) + '\t'
|
||||
+ str(np.round(metric_result[5], 2)) + '\t'
|
||||
+ str(np.round(metric_result[6], 2)) + '\t'
|
||||
+ str(np.round(metric_result[7], 2))
|
||||
)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def parse_opt():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='python.exe --irPath "红外绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
|
||||
parser.add_argument('--irPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\sar\\NH49E001013_10.tif", required=False,
|
||||
help="是否为多路径") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
|
||||
parser.add_argument('--viPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\opr\\NH49E001013_10.tif", required=False,
|
||||
help="完整目录路径!,可以为数组") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
|
||||
parser.add_argument('--outputPath', type=str, default='results_detect', required=False,
|
||||
help="输出路径!") # 使用的也是图片的目标地址
|
||||
|
||||
opt = parser.parse_args()
|
||||
return opt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = parse_opt()
|
||||
main(opt)
|
Loading…
Reference in New Issue
Block a user